From 4448494369ed0f9f6d5578d7844c4df053ff4f8e Mon Sep 17 00:00:00 2001
From: FloRul
Date: Tue, 13 Aug 2024 12:05:06 -0400
Subject: [PATCH] Squashed commit of the following:
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

commit 4c8c8813241947c8f41953a4407056eb4b761cf1
Author: Amna Mubashar
Date: Tue Aug 13 17:58:57 2024 +0200
feat: make truncation optional for bedrock chat generator (#967)
* Added truncate param to chat generator and adapters
* Added tests to check truncation
* Add doc_string
* Fixed linting

commit 0451e6f43ad5731887ab0aa2aa8e4de9020913e6
Author: HaystackBot
Date: Mon Aug 12 15:56:05 2024 +0000
Update the changelog

commit 0f1452ac3d7fcec1142473f5049cd0266b52784a
Author: David S. Batista
Date: Mon Aug 12 17:46:03 2024 +0200
refactor: change meta data fields (#911)
* initial import
* formatting
* fixing tests
* removing warnings
* linting issues
* fixes due to conflicts

commit a8b2de9d86621aba2243e2d6e350e082c087700d
Author: Stefano Fiorucci
Date: Mon Aug 12 17:19:02 2024 +0200
test: do not retry tests in `hatch run test` command (#954)
* do not retry tests in hatch run test command
* fix
* hatch config improvements

commit 93d2c6824207f0e29928fc82516b2327ff0d54d2
Author: HaystackBot
Date: Mon Aug 12 13:41:07 2024 +0000
Update the changelog

commit 7d90a58f1e77e776d5682ccbb50c87d71eee99c4
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Mon Aug 12 15:36:58 2024 +0200
fix: support streaming_callback param in amazon bedrock generators (#927)
* fix: support streaming_callback param in amazon bedrock generators
* fix chat generator merge
* reformat
---------
Co-authored-by: Thomas Stadelmann

commit f03073f995fa39a88f8ca4b14438ee6a6aa3c892
Author: Rashmi Pawar <168514198+raspawar@users.noreply.github.com>
Date: Mon Aug 12 14:30:16 2024 +0530
Add default model for NVIDIA HayStack local NIM endpoints (#915)
* initial embedder code
* default model code
* docs: update model docstring
* tests: add userwarning
* docs: literal lint fix
* review changes
* remove pydantic dependency
* move backend, nim_backend under utils
* move is_hosted to warm_up
* test cases, docstring fix
* error message updation
Co-authored-by: Madeesh Kannan
* move is_hosted code to util
* remove backend code
* update import for is_hosted
* remove util and move code to utils
* fix api key issue for failing test cases
* Update integrations/nvidia/tests/conftest.py
---------
Co-authored-by: Madeesh Kannan

commit 7f5b12e9260d97afbdf1ba550bc7ee450e9a0bfc
Author: HaystackBot
Date: Thu Aug 8 13:34:35 2024 +0000
Update the changelog

commit 76a35a7240228d1abaeb0dd50f7ccfe250bb2149
Author: Stefano Fiorucci
Date: Thu Aug 8 15:33:01 2024 +0200
chore: pin `llama-cpp-python>=0.2.87` (#955)
* pin llama-cpp-python>=0.2.86
* update version

commit 9a7a9f748a804fe08387bb0f67630611720ef319
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:46:35 2024 +0200
Docs: Update AmazonBedrockGenerator docstrings (#956)
* update docstrings
* Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit ee08a763c856411c75ca35a47cfa36ce47c9291d
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:46:17 2024 +0200
Docs: Update CohereGenerator docstrings (#960)
* Update docstrings
* Update integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
Co-authored-by: Daria Fokina
* Update integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit dccaf3f3e2f5c8eb270442f2cb5eb69f7e490bbc
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:45:50 2024 +0200
Docs: Update CohereChatGenerator docstrings (#958)
* update docstrings
* Fix formatting
* fix formatting
* Update integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
Co-authored-by: Daria Fokina
* Update integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
Co-authored-by: Daria Fokina
* Update integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit b8b72ae2bcaa3e61078df6e377a0491763f632f2
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:45:23 2024 +0200
Docs: Update GoogleChatGenerator docstrings (#962)
* Update docstrings
* Update integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit 855dc33a031fe523c2767bc10e99c8d1a76b639a
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:45:01 2024 +0200
Docs: Update GeminiGenerator docstrings (#964)
* Update docstrings
* Update integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit fe9292bfc07e6d9d005572ae17aed169fb0568e6
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Thu Aug 8 12:44:22 2024 +0200
Docs: Update NvidiaGenerator docstrings (#966)
* Update docstrings
* Fix formatting
* Update integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py
Co-authored-by: Daria Fokina
* Update integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit f6c4b242071dbdc4a08c70b71f3ed88e72e40434
Author: Mateusz Haligowski
Date: Tue Aug 6 13:31:02 2024 +0200
feat: remove gradient integration (#926)

commit e664b0cad1f9db8cc2f1abedb6b192c66ec85e47
Author: Vladimir Blagojevic
Date: Tue Aug 6 10:41:22 2024 +0100
Update Langfuse README to avoid common initialization issues (#952)

commit 4f15df0b257b44f8ba31409b6e39117009665be6
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Tue Aug 6 09:49:28 2024 +0200
Docs: Update AmazonBedrockChatGenerator docstrings (#949)
* UPdate docstrings
* Fix formatting
* Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
Co-authored-by: Daria Fokina
---------
Co-authored-by: Daria Fokina

commit 62d643bcd42b7cabad2eb9908dd68c6ea70a7c19
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Tue Aug 6 09:49:07 2024 +0200
Docs: Update BM25 docstrings (#945)
* update docstrings
* add description
* fix linters
* fix whitespaces

commit 5be0bf74d2836b81cc60de84d9f8a1541ec7e191
Author: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Date: Mon Aug 5 22:01:52 2024 +0200
Update docstrings (#947)

commit 993d99e0a645d67c5a1813218224c483406ca518
Author: HaystackBot
Date: Mon Aug 5 18:16:57 2024 +0000
Update the changelog

commit a665f1f65be1f169f21bc9d3484051f1ad1c3636
Author: Stefano Fiorucci
Date: Fri Aug 2 16:47:20 2024 +0200
introduce utility function (#939)
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

commit 7cee6c80c3ec532112ca4b7fb976711d48d72b4e
Author: Amna Mubashar
Date: Fri Aug 2 15:37:58 2024 +0200
fix: replace DynamicChatPromptBuilder with ChatPromptBuilder (#940)
* Remove occurrences of DynamicChatPromptBuilder

commit 3d698fbe9bf4115db27acf095dadcd12ce40e35b
Author: Amna Mubashar
Date: Fri Aug 2 15:21:44 2024 +0200
Pin llama-cpp version (#943)
Co-authored-by: Amna Mubashar

commit 7fbc7062f1205a4f65ca59956f4f2f1bb5a37423
Author: Stefano Fiorucci
Date: Fri Aug 2 11:44:10 2024 +0200
update docker compose to v2 (#941)

commit fa4c6ccfa7e471baf01045d9e48b118594a92b50
Author: Amna Mubashar
Date: Thu Aug 1 18:31:21 2024 +0200
Small improvement for resolving connection string (#937)

commit eab212ab5b79d782f1a840a2078ff11651ac0d8d
Author: Stefano Lottini
Date: Thu Aug 1 11:38:56 2024 +0200
fix: Astra DB, improved warnings and guidance about indexing-related mismatches (#932)
* better warning text and readme note on indexing settings
* language
* style
* style

commit f3878d481558d31b52c4fdebdb8c63c912d58324
Author: HaystackBot
Date: Thu Aug 1 08:05:02 2024 +0000
Update the changelog

commit 35d657a06eea311171da3265cbf1200d5f5bc971
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Thu Aug 1 10:03:39 2024 +0200
feat: support aws authentication with OpenSearchDocumentStore (#920)
* feat: support aws auth with OpenSearchDocumentStore
* fix: to_dict() from_dict()
* fix tests
* add tests
* fix lint
* fix mypy
* rename test class
* fix feedback
* lazy-import boto3
* move _get_auth() to AWSAuth class
* get rid of aws_auth param
* better docstrings
* apply review feedback

commit f0b619e890cb07993c14dc94ee4c4b12441b948f
Author: Vladimir Blagojevic
Date: Wed Jul 31 07:57:02 2024 +0100
chore: Remove all `DynamicChatPromptBuilder` references in Langfuse integration (#931)
* Remove all DynamicChatPromptBuilder references
* Lint fixes

commit b8e2623014be9ab7410d98832e25eda46923b04c
Author: Vladimir Blagojevic
Date: Tue Jul 30 09:04:39 2024 +0100
chore: `Langfuse` - replace DynamicChatPromptBuilder with ChatPromptBuilder (#925)

commit 2f6f134300c4e8743dc567f6a5e08ff48962af6c
Author: Rashmi Pawar <168514198+raspawar@users.noreply.github.com>
Date: Mon Jul 29 21:21:32 2024 +0530
Raise warning for base_url ../embeddings .../completions .../rankings (#922)
* add validation for base url routes
* move url validation to utils
* update docstring for url validation
* add typing for arg type
* return docstring update
Co-authored-by: Madeesh Kannan
* fix typo error
Co-authored-by: Madeesh Kannan
---------
Co-authored-by: Madeesh Kannan

commit 0cdda5ce6c5e92ef77dcfea046bc7af0bfb47aff
Author: Tobias Wochinger
Date: Wed Jul 24 18:27:45 2024 +0200
docs: add release instructions (#923)
* docs: add release instructions
* docs: add note about maintainers only

commit 84aa8356e067f2ed4a59949f8254262f5268d347
Author: HaystackBot
Date: Wed Jul 24 16:08:33 2024 +0000
Update the changelog

commit 9d511de96ddf2f9f204be00053bb7be9855ecf3a
Author: HaystackBot
Date: Wed Jul 24 13:11:25 2024 +0000
Update the changelog

commit cd521cf00e2b8fd8eed3d7048505e2ed3c3f3a92
Author: Stefano Fiorucci
Date: Wed Jul 24 12:32:05 2024 +0200
chore: Ragas - remove context relevancy metric (#917)
* ragas: remove context relevancy
* try removing rerun-failures
* add rerun-failures back, introduce pytest-asycio
* add asyncio marker
* lower-bound pin

commit 282ccc4a41f9dc9791c1f811439572d1563d0238
Author: Vladimir Blagojevic
Date: Wed Jul 24 10:45:44 2024 +0200
Use collections.list_all instead of collections._get_all (#921)

commit eaf36aa25baee410341f9153784762962a27c164
Author: HaystackBot
Date: Tue Jul 23 15:30:06 2024 +0000
Update the changelog

commit 2ed4be7cf9a84bfd6a926adadec3e05d5a8de979
Author: Anthony Tran
Date: Mon Jul 22 10:26:10 2024 -0400
Fix nested logic operators
* Normalize logical filter conditions

commit a30cb9620f5744e793968a6085696c57c06af7ec
Author: Stefano Fiorucci
Date: Mon Jul 22 10:02:36 2024 +0200
better compatibility (#914)

commit 6fb8a66b6a874afaa90dbbef874216a56782456e
Author: HaystackBot
Date: Wed Jul 17 15:26:49 2024 +0000
Update the changelog

commit 3d07b790f8d3618c56f09488c828e72463c7802b
Author: HaystackBot
Date: Wed Jul 17 15:22:09 2024 +0000
Update the changelog

commit 93eae15351ecd88105933b83b1a217d39390ce36
Author: David S. Batista
Date: Wed Jul 17 17:13:29 2024 +0200
Add meta deprecration warning (#910)
* adding deprecation warnings
* fixing imports
* fixing 3.9 linting issues

commit f03f8bc2b9c4110ec29c06528601ae3e255783b2
Author: HaystackBot
Date: Wed Jul 17 13:09:06 2024 +0000
Update the changelog

commit 4553a05336ba19909747ab053cd5316350c0027a
Author: HaystackBot
Date: Wed Jul 17 09:21:17 2024 +0000
Update the changelog

commit 9893b56d553a158264894440edefa9661589c22c
Author: Stefano Fiorucci
Date: Wed Jul 17 11:17:36 2024 +0200
fix: `ChromaDocumentStore` - discard `meta` items when the type of their value is not supported in Chroma (#907)
* discard invalid meta values
* reduce warnings

commit db2b5f72bd1c724eb346a3e4b55c3b416a1fd821
Author: Vladimir Blagojevic
Date: Wed Jul 17 11:04:57 2024 +0200
Add defensive check for filter_policy deserialization (#903)

commit be04358aefd6f125c38634acb44cb12300514664
Author: HaystackBot
Date: Wed Jul 17 07:28:22 2024 +0000
Update the changelog

commit f5c93d90ceea9bb3dacc2fe91507cd07e0e5cfb1
Author: HaystackBot
Date: Wed Jul 17 07:25:46 2024 +0000
Update the changelog

commit ac92b3a832f2c2ea352887c03d02b07e7bc0ba3a
Author: HaystackBot
Date: Wed Jul 17 07:18:41 2024 +0000
Update the changelog

commit 80622966951cc2897a689805de203f866c955819
Author: Vladimir Blagojevic
Date: Wed Jul 17 09:14:45 2024 +0200
fix: `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900)
* Add defensive check for filter_policy deserialization
* Update integrations/pgvector/tests/test_retrievers.py
Co-authored-by: David S. Batista
---------
Co-authored-by: David S. Batista

commit d943f4e62c8950c11438bb99f6314a83d02d2b53
Author: Vladimir Blagojevic
Date: Wed Jul 17 09:14:24 2024 +0200
fix: `Mongo` - Fallback to default filter policy when deserializing retrievers without the init parameter (#899)
* Add defensive check for filter_policy deserialization
* black test
* Fix ruff
* Black tests

commit b1201e01cd9c3d84d8d1abb646bec0be8441a4f7
Author: HaystackBot
Date: Tue Jul 16 09:48:13 2024 +0000
Update the changelog

commit ecaeedd3ab192bd62ad4e21a53e7201e1efac0ac
Author: Guillaume Chérel
Date: Tue Jul 16 11:44:57 2024 +0200
feat: Add metadata parameter to ChromaDocumentStore. (#906)
* feat: Add metadata parameter to ChromaDocumentStore.
* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py
Co-authored-by: Stefano Fiorucci
* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py
Co-authored-by: Stefano Fiorucci
* Update integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py
Co-authored-by: Stefano Fiorucci
* test: update logging message in chroma document store tests
* style: Fix formatting
* test: Add test for logging messages when creating chroma collection with the same name.
* test: Fix logging message.
---------
Co-authored-by: Stefano Fiorucci

commit 1b3b36e73661ca4fe6694419a9ba4a947c09f5b9
Author: Vladimir Blagojevic
Date: Tue Jul 16 10:47:14 2024 +0200
fix: `pinecone` - Fallback to default filter policy when deserializing retrievers without the init parameter (#901)
* Add defensive check for filter_policy deserialization
* Add comment

commit b33505a1fc6f653c15d59bf8e94a986f831a07a9
Author: Vladimir Blagojevic
Date: Tue Jul 16 10:46:54 2024 +0200
fix: `ElasticSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#898)
* Add defensive check for filter_policy deserialization
* Add defensive check for filter_policy deserialization
* Add unit test
* Revert change in chroma
* Linter fix

commit 6349d155cfa57b9d26c6b6ede106a86d72445b47
Author: HaystackBot
Date: Mon Jul 15 16:31:30 2024 +0000
Update the changelog

commit 52345424db8232e353516577a96b5323b276d6c2
Author: Vladimir Blagojevic
Date: Mon Jul 15 18:29:35 2024 +0200
fix: `qdrant` - Fallback to default filter policy when deserializing retrievers without the init parameter (#902)
* Add defensive check for filter_policy deserialization
* Add unit tests

commit 05a21f63279517a973360dfc792ae51438fe5e5f
Author: HaystackBot
Date: Mon Jul 15 15:59:14 2024 +0000
Update the changelog

commit 16b38492d05eff489031d99856667154aa2b88b8
Author: Vladimir Blagojevic
Date: Mon Jul 15 17:57:36 2024 +0200
fix: `Chroma` - Fallback to default filter policy when deserializing retrievers without the init parameter (#897)
* Add defensive check for filter_policy deserialization
* Add unit test
* Fix test

commit 43ccd3cea5bd5430c0a2c925810a40b62601144f
Author: HaystackBot
Date: Mon Jul 15 15:56:18 2024 +0000
Update the changelog

commit bcdf33d979528405cde73a2125dfac98e96630d7
Author: Vladimir Blagojevic
Date: Mon Jul 15 17:49:01 2024 +0200
fix: `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896)
* Add defensive check for filter_policy deserialization
* Add unit test
* Add comment

commit a9da4ed55435608ce9875ff885e5ad27b9e01027
Author: HaystackBot
Date: Mon Jul 15 15:14:56 2024 +0000
Update the changelog

commit b23ab153c86c8724edfb341fce680f5a38004162
Author: Madeesh Kannan
Date: Mon Jul 15 17:13:15 2024 +0200
fix: `OpenSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#895)

commit 20221ac6c53e3d9dedcec9a2da3289178bf7c495
Author: HaystackBot
Date: Mon Jul 15 14:20:05 2024 +0000
Update the changelog

commit 90255b472e556aec861432ffc10a0a3d06913d98
Author: HaystackBot
Date: Mon Jul 15 13:57:20 2024 +0000
Update the changelog

commit 140015b8c81918a7fe4726d89a43c1acbbef0907
Author: David S. Batista
Date: Mon Jul 15 15:27:11 2024 +0200
Update README.md (#893)
updating Amazon Bedrock link

commit dfebd7d6e60e2c161f1790db7f4391a1d880460d
Author: HaystackBot
Date: Mon Jul 15 09:57:17 2024 +0000
Update the changelog

commit c8f19a2fb734e6284dca6bf042c964f67a42b051
Author: Amna Mubashar
Date: Wed Jul 10 10:52:35 2024 +0200
fix: errors in convert_filters_to_qdrant (#870)
* progress
* Fixed logic error
* Some tests are still failing
* Passed all tests
* Fixed errors in logic
* Fixed linting issues
* Minor adjustments
* Further improvements in code structure
* Final changes for review
* Updated
* Added more tests
* Add a test to check nested filters
* Minor changes
* Fix bugs and add docstrings
---------
Co-authored-by: Amna Mubashar

commit 1ecfbfa6d08f24b1bd24ff83b6ae6941e40ab352
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Tue Jul 9 15:54:39 2024 +0200
Fix Google AI tests failing (#885)
* Fix Google AI tests failing
* Fix GoogleAIGeminiChatGenerator to_dict and from_dict

commit abfe76e5f2c0193736beab0b42b69af19fb0934d
Author: Amna Mubashar
Date: Mon Jul 8 19:00:43 2024 +0200
fix: Allow search in ChromaDocumentStore without metadata (#863)
* Fix a bug for checking metadata values

commit 5ab3a10c68f4834b72af74d2d996cf0fe6359f43
Author: Vladimir Blagojevic
Date: Mon Jul 8 17:20:37 2024 +0200
Minor retriever pydoc fix (#884)

commit 11a07449734c504de821a551a238c086f1a3d0e8
Author: Vladimir Blagojevic
Date: Fri Jul 5 16:35:34 2024 +0200
feat: Add filter_policy to chroma integration (#826)
* Add filter_policy to chroma integration

commit 124b6e801f4cfff38b65c26eda033fee5fa23d08
Author: Vladimir Blagojevic
Date: Fri Jul 5 16:35:14 2024 +0200
feat: Add filter_policy to mongodb_atlas integration (#823)
* Add filter_policy to mongodb_atlas integration

commit 7c60cbf204f1edd039285f460920ef09c5a8243f
Author: Vladimir Blagojevic
Date: Fri Jul 5 16:34:55 2024 +0200
feat: Add filter_policy to opensearch integration (#822)
* Add filter_policy to opensearch integration

commit e6d378a6235d13170d66e93db3f4303d2bb26cc6
Author: Vladimir Blagojevic
Date: Fri Jul 5 16:34:32 2024 +0200
feat: Add filter_policy to qdrant integration (#819)
* Add filter_policy to qdrant integration

commit 2182edb01a0a3a4379a72e0fdf94cd3640a0fa9c
Author: Vladimir Blagojevic
Date: Fri Jul 5 15:40:11 2024 +0200
feat: Add filter_policy to elasticsearch integration (#825)
* Add filter_policy to elasticsearch integration

commit dff8518588374f67ab55c81a65ec23076710be40
Author: Vladimir Blagojevic
Date: Fri Jul 5 13:00:06 2024 +0200
feat: Add filter_policy to pinecone integration (#821)
* Add filter_policy to pinecone integration

commit cf792d78151967378505ae7b0581ebfa280f7bdb
Author: Vladimir Blagojevic
Date: Fri Jul 5 12:26:40 2024 +0200
feat: Add filter_policy to pgvector integration (#820)
* Add filter_policy to pgvector integration
---------
Co-authored-by: Stefano Fiorucci

commit 978b71d4dbd9d19c2e616d80c686e48cf876a3da
Author: Vladimir Blagojevic
Date: Fri Jul 5 12:26:10 2024 +0200
feat: Add filter_policy to weaviate integration (#824)
* Add filter_policy to weaviate integration

commit 24c56e96b7b74b2780dab8620a6a232ccb99cde9
Author: Vladimir Blagojevic
Date: Fri Jul 5 12:25:38 2024 +0200
feat: Add filter_policy to astra integration (#827)
* Add filter_policy to astra integration
---------
Co-authored-by: Stefano Fiorucci

commit 52b0a1a59187c1b0001c2c07351438554e0788b1
Author: David Basoco
Date: Thu Jul 4 15:18:51 2024 +0200
Fix not equal astra filter operator (#868)

commit ab68d24a688228d8d37e5c8e3f9264d9c5fcdeee
Author: HaystackBot
Date: Thu Jul 4 08:54:14 2024 +0000
Update the changelog

commit fcbc35b972d8a06dbe5c97edb11f308297736fa0
Author: Madeesh Kannan
Date: Thu Jul 4 10:51:25 2024 +0200
fix: Fix typo in the `ORTModel.inputs_names` field to align with upstream (#866)

commit 0d89e832f84354be0c19464cb9915a8d4aa46a3c
Author: HaystackBot
Date: Wed Jul 3 13:52:38 2024 +0000
Update the changelog

commit 87bb97dffc2aaebd9238866c11fac57c2332fa9f
Author: HaystackBot
Date: Wed Jul 3 11:10:43 2024 +0000
Update the changelog

commit 0fd154b97a8621e29f5f70010cc9f025d7eed245
Author: Stefano Fiorucci
Date: Wed Jul 3 13:06:55 2024 +0200
feat: Qdrant - add support for BM42 (#864)
* Qdrant: add support for BM42
* add test for sparse configuration

commit fd0059e8ce8dc0db338321d2ac5f92b0c1be985a
Author: Isaac Chung
Date: Wed Jul 3 13:49:24 2024 +0300
feat: add `score_threshold` to Qdrant Retrievers (#860)
* feat: add score_threshold to qdrant retrievers
* test: add score_threshold to qdrant tests
* ruff linting
* hatch run lint:all
* add test case using score_threshold
* linting
* test: new test case with fixed embeds per review
* expand docstrings
* small fixes
---------
Co-authored-by: anakin87

commit f73c3514c13cdaaf10a878e88f5987529958e9d2
Author: HaystackBot
Date: Tue Jul 2 09:13:17 2024 +0000
Update the changelog

commit 9c86675bb00d587410335b6b32c28c8b94a1c795
Author: Stefano Fiorucci
Date: Tue Jul 2 11:01:01 2024 +0200
refactor!: Qdrant - set `scale_score` default value to `False` (#862)
* rm unused params
* qdrant - set scale_score to False

commit 06d77769199607c717cffa297d9b71e51bee4ed4
Author: Stefano Fiorucci
Date: Tue Jul 2 09:50:43 2024 +0200
refactor!: Qdrant - remove unused init parameters: `content_field`, `name_field`, `embedding_field`, and `duplicate_documents` (#861)
* rm unused params
* docs: change duplicate_documents to policy in docstring
---------
Co-authored-by: Julian Risch

commit 268b487a2e8633acecc917e51746eafb2040a9a6
Author: Amna Mubashar
Date: Tue Jul 2 01:45:31 2024 +0200
feat: made truncation optional for BedrockGenerator (#833)
* Added truncate parameter to init method
* fixed serialization bug for BedrockGenerator
* Add a test to check truncation functionality

commit 2d93ea3abf2141bca2395ead36d28d9dff8bb413
Author: HaystackBot
Date: Mon Jul 1 18:44:29 2024 +0000
Update the changelog

commit 23c3e108cad5aae08c1deea8c6cc1162d0e900df
Author: Amna Mubashar
Date: Mon Jul 1 15:33:46 2024 +0200
Add system files in git ignore (#858)
Co-authored-by: Amna Mubashar

commit 7127be63c9aa152b5139c1e826f7d345f912b178
Author: Amna Mubashar
Date: Mon Jul 1 01:50:01 2024 +0200
feat: added distance_function property to ChromadocumentStore (#817)
* Added the distance metric property
---------
Co-authored-by: Amna Mubashar
Co-authored-by: Stefano Fiorucci

commit 6d8ce95005ffedcfa76347e0d06820f2c8490092
Author: HaystackBot
Date: Fri Jun 28 15:49:09 2024 +0000
Update the changelog

commit 1f582d7dee069209696084236438a0f09e2b8bbf
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Fri Jun 28 12:40:45 2024 +0200
feat: add raise_on_failure param to OpenSearch retrievers (#852)
* feat: add ignore_errors param to OpenSearch retrievers
* feedback
* fix ruff command
* fix except

commit 605ba29f0ec722c103085e83ac9bd32762296105
Author: HaystackBot
Date: Fri Jun 28 10:30:15 2024 +0000
Update the changelog

commit d2fd97af5d5a28b9d56f233f45b5c6664243c8ee
Author: Stefano Fiorucci
Date: Fri Jun 28 12:26:04 2024 +0200
build: add `psutil` dependency to Unstructured integration (#854)
* add psutil dependency to unstructured integration
* fxi

commit 9aafa795935a4a70614d43c7afad0be0e2b12c28
Author: Vladimir Blagojevic
Date: Fri Jun 28 11:32:10 2024 +0200
chore: Update ruff invocation to include check parameter (#853)
* Update ruff invocation to include check parameter
* fix linting Sagemaker
* unused import
---------
Co-authored-by: anakin87

commit 439945410b7968759427e02ff81deaccca313f10
Author: HaystackBot
Date: Thu Jun 27 11:32:56 2024 +0000
Update the changelog

commit 0039a486f6eb01f2730238a62b76253b185efa5f
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Thu Jun 27 13:28:21 2024 +0200
fix: support legacy filters with OpenSearchDocumentStore (#850)
* feat: support legacy filters with OpenSearchDocumentStore
* add tests

commit f170ab434711d24a66c9c1a63575c9968e54824d
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Wed Jun 26 19:13:32 2024 +0200
fix: serialization for custom_query in OpenSearch retrievers (#851)

commit 49e323f037527d0552ddd387ac42d8b888884782
Author: paulmartrencharpro <148542350+paulmartrencharpro@users.noreply.github.com>
Date: Wed Jun 26 14:19:53 2024 +0200
Fix: typo on Sparse embedders. The parameter should be "progress_bar" … (#814)
* Fix typo on Sparse embedders. The parameter should be "progress_bar" instead of "show_progress_bar"
* Fix typo on Sparse embedders tests. The parameter should be "progress_bar" instead of "show_progress_bar"

commit bd21df73ed9d298c88e538cfc2f1ede191722863
Author: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed Jun 26 10:51:19 2024 +0200
chore(deps): bump actions/add-to-project from 1.0.1 to 1.0.2 (#849)
Bumps [actions/add-to-project](https://github.com/actions/add-to-project) from 1.0.1 to 1.0.2.
- [Release notes](https://github.com/actions/add-to-project/releases)
- [Commits](https://github.com/actions/add-to-project/compare/v1.0.1...v1.0.2)
---
updated-dependencies:
- dependency-name: actions/add-to-project
  dependency-type: direct:production
  update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

commit 6945503db2f72089dd6082a4095755f80dd3dcbc
Author: Julian Risch
Date: Wed Jun 26 08:50:16 2024 +0200
feat: Use non-gated tokenizer as fallback for mistral in AmazonBedrockChatGenerator (#843)
* feat: Use non-gated tokenizer as fallback for mistral
* formatting
* fix linter issues

commit 05ccdb2a9755fc00cb10381fcf1e6c82c7b12cd2
Author: HaystackBot
Date: Tue Jun 25 16:11:17 2024 +0000
Update the changelog

commit 69c29a95d82a7f2bbd53949257a926351521879e
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Tue Jun 25 17:24:10 2024 +0200
feat: add custom_query param to OpenSearch retrievers (#841)
* feat: add custom_query param to OpenSearch retrievers
* feat: add custom_query to OpenSearch retrievers
* add as run param
* fix lint
* switch to jinja2 templates
* Revert "switch to jinja2 templates"
This reverts commit f36ed13fa25abc5d17df7e087841a9ecf839c75f.
* support custom_query as dict
* remove unneccessary comments
* remove str
* fix lint

commit be09adf256a107536d5d3b5434bdb70deefa58e3
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Tue Jun 25 17:19:02 2024 +0200
feat: add create_index option to OpenSearchDocumentStore (#840)
* [opensearch] feat: add create_index option
* fix lint
* fix lint
* add create_index() method
* fix lint
* better match
* fix docs

commit 1c557cb06a3e19338e1ff4bd53494bb69c83695d
Author: HaystackBot
Date: Tue Jun 25 13:17:50 2024 +0000
Update the changelog

commit 53f26ec92e233fa0e8dfa7757ae9236268086cd3
Author: Stefano Fiorucci
Date: Tue Jun 25 10:43:21 2024 +0200
update Pinecone test scripts (#848)

commit 60c666624cfb20a3600a28be074697fdf4cf2553
Author: Stefano Fiorucci
Date: Mon Jun 24 16:31:47 2024 +0200
install pytest-rerunfailures; change test-cov script (#845)

commit 7723ddee7c17356f1393dc15b48711c73ba0663a
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Mon Jun 24 15:35:30 2024 +0200
fix: fix connection to Weaviate Cloud Service (#624)
* Fix connection to Weaviate Cloud Service
* Handle connection to WCS and add tests
* Add comment explaining why we use utility function

commit 14499cd48a0c766b10a79814c5d3c9bf8d1e6bb2
Author: Silvano Cerza
Date: Mon Jun 24 13:02:12 2024 +0200
Revert "Handle connection to WCS and add tests"
This reverts commit f48802b2ce612896fd06a13cf33dffd9f77a8859.

commit f48802b2ce612896fd06a13cf33dffd9f77a8859
Author: Silvano Cerza
Date: Mon Jun 24 13:01:27 2024 +0200
Handle connection to WCS and add tests

commit 75bb792e956f038992be8b8d84c4b7b94e42b9db
Author: HaystackBot
Date: Fri Jun 21 15:01:10 2024 +0000
Update the changelog

commit 7cd71f17a3f86a22f74e85db2825f993c597e5b8
Author: Vladimir Blagojevic
Date: Fri Jun 21 16:59:14 2024 +0200
feat: Update Anthropic default models, pydocs (#839)
* Update default models, pydocs
* Update unit test

commit 1f7f75a663292df2cc322c526a80be0fbba7cdee
Author: Vladimir Blagojevic
Date: Fri Jun 21 14:03:29 2024 +0200
feat: Update Cohere default LLMs, add examples and update unit tests (#838)
* Update default models, examples and unit tests
* PR feedback

commit 8152b6a3c29eb3ebd0727646b50b0bd71d791ebe
Author: Massimiliano Pippi
Date: Fri Jun 21 13:20:39 2024 +0200
ci: retry tests to reduce flakyness (#836)
* test the strategy
* cargo load the whole bunch

commit 4b3abda8eec9eaa43c87eb37427c69a78760ba20
Author: Massimiliano Pippi
Date: Fri Jun 21 09:57:23 2024 +0200
clean up workflow files (#835)

commit b11486fb7a20728ed89b52243200cc8184272ab8
Author: HaystackBot
Date: Thu Jun 20 13:35:54 2024 +0000
Update the changelog

commit e42f8f2935676a55ed0c596baac5f026e8195fbb
Author: Amna Mubashar
Date: Thu Jun 20 13:19:43 2024 +0200
doc: added docstrings for QdrantDocumentStore (#808)
* doc: add docstrings qdrant document store
* Updated docstrings based on PR review
---------
Co-authored-by: Amna Mubashar

commit a0eca9ae434f4774ed7401700879f9a96db8014a
Author: HaystackBot
Date: Thu Jun 20 08:06:17 2024 +0000
Update the changelog

commit d161220395aabf45fdc049ee7b129448406e39d6
Author: Leonardo Teixeira Menezes
Date: Thu Jun 20 10:02:31 2024 +0200
feat: add customizable index names for pgvector (#818)
* feat: add customizable index names for pgvector
* refactor: Remove unnecessary constants on PGVector document store

commit 867b2a3b66eb75e735726f2b9780778ec59cd33b
Author: Vedant Naik <52022480+vedantnaik19@users.noreply.github.com>
Date: Wed Jun 19 15:55:41 2024 +0100
fix: weaviate filter error (#811)
* fix: weaviate filter error
* test: add test for legacy filters for weaviate

commit 236fd28287c42d34b0f0cb48426625f3e747d797
Author: HaystackBot
Date: Wed Jun 19 04:15:36 2024 +0000
Update the changelog

commit b9f783dfd7d3b18e749889a1031173d9643940c9
Author: HaystackBot
Date: Tue Jun 18 15:33:22 2024 +0000
Update the changelog

commit 70e2a9cf9315e01c292105acdf3eecb150793336
Author: Massimiliano Pippi
Date: Tue Jun 18 17:32:04 2024 +0200
add support for Azure generators (#815)

commit 6254d58019101b7ec9d239fe9ee7b3adc504c38a
Author: Vladimir Blagojevic
Date: Tue Jun 18 10:27:04 2024 +0200
feat: Update Anthropic/Cohere for tools use (#790)
* Update for tools use
* Test updates
* Add tools usage integration tests
* Minor test detail update
* PR review
* Pydocs update

commit 14a2711079c565b4eb7baaccbf1f3f31c7a646b8
Author: HaystackBot
Date: Fri Jun 14 15:09:05 2024 +0000
Update the changelog

commit 5e66f1d370cc33d6e4f29019bfdc368fc63182c8
Author: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Fri Jun 14 17:06:51 2024 +0200
feat: support Claude v3, Llama3 and Command R models on Amazon Bedrock (#809)
* feat: support Claude v3 and Cohere Command R models on Amazon Bedrock
* revert chat pattern change
* rename llama adapter
* fix tests after llama adapter rename

commit 590e2b016009d171eeb9b3aae95467f295936453
Author: HaystackBot
Date: Thu Jun 13 10:27:46 2024 +0000
Update the changelog

commit bf5c64138cfcb0feea173d4e2289754922ae5a41
Author: agruhl
Date: Thu Jun 13 12:26:05 2024 +0200
fix: Performance optimizations and value error when streaming in langfuse (#798)
* Solves issue with usage stats when streaming is enabled on the OpenAIGenerator
* Root span should be closed when the pipeline run is complete
* Added documentation
* Moved flushing execution to the last span in the context and improved the documentation to give examples of flushing properly manually
* Fixed linting issues
* make use of monkeypatch
* improving code

commit 575e209f7020a327b3347368eb29e2cc9d01ac71
Author: Stefano Fiorucci
Date: Wed Jun 12 19:17:53 2024 +0200
tests: Pinecone - fix `test_serverless_index_creation_from_scratch` (#806)
* pinecone tests: wait for index creation
* make index name dynamic
* lint
* run the test only once in our suite
* add reason
* increase sleep time
* better skipif condition
* better coverage options
* fix
* revert changes in coverage.run
* fix
* add unit to sleep tima
* define index name in the matrix
* fix
* add default
* lint

commit 015cc3efc8da4256f75c7ae568bc8e090137a480
Author: HaystackBot
Date: Wed Jun 12 09:50:16 2024 +0000
Update the changelog

commit 9f5a4601018c1a5707da02c43ebbd75b43d78640
Author: Stefano Fiorucci
Date: Wed Jun 12 11:48:35 2024 +0200
cohere - remove warning (#805)

commit eb0722b90220b868d69764901342badabadee196
Author: Stefano Fiorucci
Date: Tue Jun 11 12:54:04 2024 +0200
test: Amazon Bedrock - skip integration tests from forks (#801)
* try skipping integration tests from forks
* make it work on windows
* skip mistral tests when HF token is not set
* add windows step
* separate unit and int tests
* refinement
* format

commit 4821ff3dbc86a51c652651df6bf626ced1412f67
Author: HaystackBot
Date: Tue Jun 11 09:05:42 2024 +0000
Update the changelog

commit 7524022524ca38889a39744b2661913ef2552206
Author: Massimiliano Pippi
Date: Tue Jun 11 11:03:27 2024 +0200
feat: defer the database connection to when it's needed (#804)
* feat: defer the database connection to when it's needed
* linting
* fix tests

commit 29b363c2cea1c677b3d0ba23de4974d9e033b710
Author: HaystackBot
Date: Mon Jun 10 16:33:29 2024 +0000
Update the changelog

commit f70664dd5d276a801b3544ad8943b95bdd064deb
Author: Massimiliano Pippi
Date: Mon Jun 10 18:31:06 2024 +0200
feat: defer the database connection to when it's needed (#802)
* feat: defer the database connection to when it's needed
* remove unneeded noqa
* fix fixture
* trigger the connection before asserting
* trigger connection
* make also serialization lazy
* remove copypasta leftovers

commit 5bc08dfba2be655b149137c2e852d9a58bd8172c
Author: Madeesh Kannan
Date: Mon Jun 10 17:41:05 2024 +0200
refactor: Remove deprecated Nvidia Cloud Functions backend and related code. (#803)

commit 76504ab95bf7bbfaebaa2f06bae3dab892df2ae9
Author: HaystackBot
Date: Mon Jun 10 15:03:23 2024 +0000
Update the changelog

commit 4dda1b9ec4eea7f48cd2546da4c36be7f4134210
Author: Stefano Fiorucci
Date: Mon Jun 10 17:00:36 2024 +0200
feat!: Pinecone - support for the new API (#793)
* upgrate to new API
* increase sleep time
* update example
* address feedback

commit 1f71c721049f91634a96d61c5b505e66649655c5
Author: HaystackBot
Date: Mon Jun 10 08:08:33 2024 +0000
Update the changelog

commit 2b4a534e21ca9c0a148011ecd2daac9c02556206
Author: HaystackBot
Date: Fri Jun 7 07:13:42 2024 +0000
Update the changelog

commit 78157a94e7e2d0cd3ae8ac42508e9a248b96852b
Author: lohit8846
Date: Thu Jun 6 00:17:50 2024 -0700
feat: Add force_disable_check_same_thread init param for Qdrant local client (#779)
* fix: added missing init param which is use by qdrant local client
* fix: Corrected linting

commit 3f3fa2bd2185733621b6a85becd4713e0074b7d1
Author: HaystackBot
Date: Wed Jun 5 14:46:41 2024 +0000
Update the changelog

commit bc032a37a20a38ebc3a6b9bc7fe342d4e3e5411c
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed Jun 5 16:44:57 2024 +0200
Fix tests skipping for Google AI integration (#788)

commit 48422b1e91e3d71df149ae04a36e2d4e506dcc5c
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed Jun 5 16:37:15 2024 +0200
Drop support for Python 3.8 for Unstructured integration (#787)

commit ab58a2319fa3df2c51ec68864a5da09105bf1fef
Author: antoniomuzzolini <39350879+antoniomuzzolini@users.noreply.github.com>
Date: Wed Jun 5 16:32:06 2024 +0200
fix: Handle `TypeError: Could not create Blob` in `GoogleAIGeminiChatGenerator` (#772)
* bugfix 654 aligned _message_to_content to VertexAIGeminiChatGenerator
* Update test_chat_gemini.py
---------
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

commit 4290d6a3fad63701882fa4b70af8aafb94db7d89
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed Jun 5 16:22:31 2024 +0200
Fix Google AI integration tests (#786)

commit 4042fddcfe5682885b711260812a37f6323bf040
Author: ArzelaAscoIi <37148029+ArzelaAscoIi@users.noreply.github.com>
Date: Wed Jun 5 14:32:07 2024 +0200
feat: return_embeddings flag for opensearch (#784)
* feat: return_embeddings flag for opensearch
* fix
* fix
* fix
* tests: bm25 retrieval

commit 42d85b3a2936dc1275558d99907b543a24aa595b
Author: Matthew Farrellee
Date: Wed Jun 5 07:42:44 2024 -0400
align tests and docs on NVIDIA_API_KEY (instead of NVIDIA_CATALOG_API_KEY) (#731)

commit e118b926456d1c52ad35ef1a7fa1a1e274487d58
Author: Matthew Farrellee
Date: Wed Jun 5 07:42:37 2024 -0400
switch default from NVCF backend to hosted NIM backend (#734)

commit 6aaf523636d923ce158c46ea39a057f5529b5a17
Author: HaystackBot
Date: Mon Jun 3 07:40:19 2024 +0000
Update the changelog

commit f7996741d9b78fce418e85bf886f7c5664e4beda
Author: Etienne
Date: Fri May 31 10:54:40 2024 -0400
feat: Add streaming support to OllamaChatGenerator (#757)
* Add streaming support to OllamaChatGenerator
* Clean imports, update docstring
* Organize imports
* Fix test
* Reformat code
* Optimize imports
* Add test for streaming callback
---------
Co-authored-by: Silvano Cerza

commit 145751535df0759c5dc4a93801a9e9f474f5a103
Author: HaystackBot
Date: Fri May 31 09:57:08 2024 +0000
Update the changelog

commit 43c2cf39caa59505cb50e0180dc13819e0ab9940
Author: Stefano Fiorucci
Date: Fri May 31 11:54:53 2024 +0200
pin chromadb (#777)

commit 8a79bfa0fa4108309ed4c2d16ad7633d5305dd82
Author: Massimiliano Pippi
Date: Fri May 31 08:49:06 2024 +0200
fix typing checks

commit 5a84d41d2bf2f919931399eddfe634f34e03cd63
Author: HaystackBot
Date: Thu May 30 09:21:04 2024 +0000
Update the changelog

commit 5098fbdb5c396cada8f5844d0c3a0c98deb04b00
Author: HaystackBot
Date: Thu May 30 09:19:41 2024 +0000
Update the changelog

commit f805abc9687e6f78aa3b6cbedb37093644a78925
Author: HaystackBot
Date: Thu May 30 09:18:13 2024 +0000
Update the changelog

commit b07dc1e2372598ecfa86eae2d1cd80ad68f06531
Author: HaystackBot
Date: Thu May 30 09:16:57 2024 +0000
Update the changelog

commit b966f182f0f55f028a19032c8bda1738f1ee2af9
Author: HaystackBot
Date: Thu May 30 09:12:42 2024 +0000
Update the changelog

commit 42339cdbdcdc63ded4aaaa760771a610b9406dbb
Author: Massimiliano Pippi
Date: Thu May 30 11:09:09 2024 +0200
fix: pass empty dict to filter instead of None (#775)

commit f68b0dc48d282f7abdf05e83fb3251cb09b1a26f
Author: Massimiliano Pippi
Date: Thu May 30 09:19:46 2024 +0200
fix astra nightly

commit c58b319596c81a7f6af2adf03933cdd8865ce614
Author: Massimiliano Pippi
Date: Wed May 29 22:08:35 2024 +0200
show the unreleased changes for an integration (#774)

commit e63b2b03fc9e4fdca40afc6cf0e7a7ca3ceb2866
Author: Massimiliano Pippi
Date: Wed May 29 14:22:16 2024 +0200
feat: defer the database connection to when it's needed (#773)
* fix tests
* fix cursor creation

commit 588d6549e0772aa069355e91e809aa638b569c05
Author: Massimiliano Pippi
Date: Wed May 29 14:22:02 2024 +0200
feat: defer the database connection to when it's needed (#770)
* feat: defer the database connection to when it's needed
* lazy collection too
* add test
* linting

commit 5eebd8444eb18e24cc448e33aa7761ff630025aa
Author: Massimiliano Pippi
Date: Wed May 29 10:36:44 2024 +0200
feat: defer the database connection to when it's needed (#769)
* feat: defer the database connection to when it's needed
* fix typing
* test init is lazy

commit 7d36d02f47bea45dd8c48235fb0d850949e13aba
Author: Massimiliano Pippi
Date: Wed May 29 10:36:14 2024 +0200
feat: defer the database connection to when it's needed (#766)
* fix tests
* fix linting

commit 257f99276f35abe68c04e2f616e3669b43fd010f
Author: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Date: Mon May 27 09:16:20 2024 +0200
feat: Improve `OpenSearchDocumentStore.__init__` arguments (#739)
* add max_chunk_bytes parameter to OpenSearchDocumentStore init
* more documentation in opensearch init
* fix tests
* use `max_chunk_bytes` in delete_documents
* re added type ignore
* add new kwargs to to_dict
* test default mappings
* Update document_store.py
* restore C++
---------
Co-authored-by: Massimiliano Pippi

commit 51d0be00c20bef29090d1d37837b6de830d3b43e
Author: Massimiliano Pippi
Date: Fri May 24 22:22:50 2024 +0200
allow unconventional

commit 8b5d5906b29899e9ca59c158abec6e5f3fd0f840
Author: HaystackBot
Date: Fri May 24 19:40:00 2024 +0000
Update the changelog

commit f3d6b130ffd7eaf158f1df45baf82c162d7a21f6
Author: Massimiliano Pippi
Date: Fri May 24 21:39:15 2024 +0200
use git-cliff-action directly

commit 6eac9df34b83728334fc1750631f6d65ba702b4b
Author: HaystackBot
Date: Fri May 24 19:27:04 2024 +0000
Update the changelog

commit 8fbf0876b00b89362a3367974a1d1dfe05a45fdd
Author: Massimiliano Pippi
Date: Fri May 24 21:26:09 2024 +0200
fetch the full history

commit b591fb5b2f92d14e116bef0f9b6723c4035aa73a
Author: HaystackBot
Date: Fri May 24 19:14:17 2024 +0000
Update the changelog

commit d4c87ecbd9c9fbcd0b54aa937bff59fac21c81b6
Author: Massimiliano Pippi
Date: Fri May 24 21:13:24 2024 +0200
push to main

commit 39be0e750964d0acbaa047b2cdcb6f3cde07a98c
Author: Massimiliano Pippi
Date: Fri May 24 21:06:42 2024 +0200
introduce EndBug/add-and-commit

commit e89f6d888d7afe99f78fe1b4231c8ff832a62a4f
Author: Massimiliano Pippi
Date: Fri May 24 18:38:15 2024 +0200
restore GITHUB_TOKEN

commit 4238980b1e33a3cfa4d520a2bcae21b738df2279
Author: Massimiliano Pippi
Date: Fri May 24 17:35:31 2024 +0200
stup git config

commit 9c4a17fcfca85d8f76cc2596590470186db9dc71
Author: Massimiliano Pippi
Date: Fri May 24 17:33:37 2024 +0200
authenticate checkout step

commit 216b73421fcfe29856499aa4653272dd2e3e5ec7
Author: Massimiliano Pippi
Date: Fri May 24 17:15:52 2024 +0200
try setting git user explicitly

commit 36828c59228e520f944824d053445157cc084e3a
Author: Massimiliano Pippi
Date: Fri May 24 17:01:29 2024 +0200
fetch before checkout

commit d6c7c573492eefdf380a0dcb4cb594bd8f977183
Author: Massimiliano Pippi
Date: Fri May 24 16:59:34 2024 +0200
fix branch checkout

commit 6ff501154ce3c025e6755446088ed5e3dd988733
Author: Massimiliano Pippi
Date: Fri May 24 16:56:04 2024 +0200
setup git config

commit 051ca4924c0e6a8f776a9e8ed26df51b2a7e84d3
Author: Massimiliano Pippi
Date: Fri May 24 16:53:31 2024 +0200
feat: defer the database connection to when it's needed (#753)
* defer the database connection to when it's needed
* avoid accessing _client, use the property instead
* ignore mypy errors on private field

commit d914669fa71f9bce299a17abe177f7c89c793c18
Author: Massimiliano Pippi
Date: Fri May 24 16:52:54 2024 +0200
generate integrations changelog

commit 38dc95f2b60d011476c1fb3b46895b05c416b472
Author: Ruben <38215798+ruben-vb@users.noreply.github.com>
Date: Fri May 24 12:05:43 2024 +0200
feat: make get_distance and recreate_collection public, replace deprecated recreate_collection function (#754)
* make get_distance and recreate_collection public
* replace deprecated recreate_collection function
* make on_disk and use_sparse_embeddings optional for recreate_collection
* use client.collection_exists instead of try-catch
---------
Co-authored-by: Massimiliano Pippi

commit 3a67349dc572269049d65db13d4eed3b5acec637
Author: Massimiliano Pippi
Date: Fri May 24 12:03:13 2024 +0200
fix: remove support for generate API (#755)
* remove support for generate API
* add note about super()

commit 29469fa16c4f02b4d0b9d304bcf7e3810bea23e4
Author: Jan Beitner
Date: Fri May 24 11:00:12 2024 +0100
Allow vanilla qdrant filters (#692)
* allow vanilla qdrant filters
* updated signatures
* fix formatting
* Fix type check in filter_documents and add small test
---------
Co-authored-by: Silvano Cerza

commit 2667d6bc6077a8db176b257bf178a1df45153214
Author: Anushree Bannadabhavi
Date: Fri May 24 05:40:41 2024 -0400
fix: add support for custom mapping in ElasticsearchDocumentStore (#721)
* Add custom mapping in ElasticsearchDocumentStore init
* Update docstrings and add test
* Fix linting
* Fix retrievers tests
---------
Co-authored-by: Silvano Cerza

commit 95daee37d79799f5134c7bed7889865084aae114
Author: Massimiliano Pippi
Date: Fri May 24 08:15:03 2024 +0200
feat: defer database connection to the first usage (#748)
* defer the database connection to the first usage of the client
* add test to avoid regressions

commit ee2d54baf5dacf3817a0447b8d34390f108b6d2a
Author: mohammedsohel
Date: Thu May 23 19:10:42 2024 +0530
adding support of "amazon.titan-embed-text-v2:0" (#735)
* adding support of "amazon.titan-embed-text-v2:0"
* rectifying the format
---------
Co-authored-by: Massimiliano Pippi

commit 428c2a8a9c22dc407a203d3ebc502a153a70505b
Author: Massimiliano Pippi
Date: Thu May 23 11:49:12 2024 +0200
amend PR template and CoC

commit 660d73ddb48801bc0ea34f056d11494699d77bd1
Author: Massimiliano Pippi
Date: Thu May 23 11:44:17 2024 +0200
add PR template and CoC

commit 63cf323abf97c9d7221a610a9dcf2a45c17d3fb3
Author: Vishal
Date: Thu May 23 14:33:32 2024 +0530
fix: max_tokens typo in Mistral Chat (#740)

commit 31e61b72571dbc699577af2a573e7031aeb379e8
Author: Tuana Çelik
Date: Wed May 22 12:30:26 2024 +0200
Update _nim_backend.py (#744)
* Update _nim_backend.py
* Update _nim_backend.py

commit 7141c68c25d64192e8fa7e341e9d6fb5eb3f5612
Author: jlonge4 <91354480+jlonge4@users.noreply.github.com>
Date: Sat May 18 05:12:19 2024 -0400
[deepset-ai/haystack-core-integrations#727] (#738)
* hybrid retrieval ex
* Update integrations/pgvector/examples/hybrid_retrieval.py
Co-authored-by: Stefano Fiorucci
* suggested updates
* suggested updates
* suggested updates
---------
Co-authored-by: Stefano Fiorucci

commit c4f1cc48c8e83baf12335ec1c35e424d1a0975e9
Author: Stefano Lottini
Date: Wed May 15 21:49:21 2024 +0200
explicit projection when reading from Astra DB (#733)

commit 6f298cec4c3c374f714e9eb753162026666cb525
Author: paulmartrencharpro <148542350+paulmartrencharpro@users.noreply.github.com>
Date: Wed May 15 18:11:42 2024 +0200
Use the local_files_only option available as of fastembed==0.2.7. It … (#736)
* Use the local_files_only option available as of fastembed==0.2.7. It allows to not look for the models online, but only use the local, cached, files. This way, we can download the model once then use this without internet access
* Fix lint issues
* add same param to doc embedder
---------
Co-authored-by: anakin87

commit 0e02fd65f432f10c77fcd066a7064bf5ba7223a3
Author: Ulises M <30765968+lbux@users.noreply.github.com>
Date: Mon May 13 05:44:45 2024 -0700
basic implementation of llama.cpp chat generation (#723)
* basic implementation of llama.cpp chat generation
allows for constraining to json
allows for function calling (not tested)
streaming needs to be implemented when stream is set to true in generation_kwargs
* add testing
* remove unnecessary function
* slight documentation fix, comment out broken test
* support for function calling through functionary
also add a basic rag test
* add function calling and execute test, it works!
* add json test, add chatml test
* make function call and execute more deterministic
* try removing additional deps
* revert
* make transformers a tests-only dependency
---------
Co-authored-by: Stefano Fiorucci

commit d4a598b6f5d3287ead9e8c234b001ea0f15e376b
Author: Jon
Date: Fri May 10 00:30:48 2024 -0700
Implement filters for chromaQueryTextRetriever via existing haystack filters logic (#705)
* Implement filters for chromaQueryTextRetriever via existing haystack filters logic
Run linter
* un-skip tests
---------
Co-authored-by: Massimiliano Pippi

commit c29db9c913c8326c28d2f3de091429d95ee73b38
Author: Daria Fokina
Date: Thu May 9 19:10:51 2024 +0200
missing api references (#728)

commit 7b4428d934e4304f186cd01e02bcf239a43f5729
Author: Vladimir Blagojevic
Date: Wed May 8 17:30:52 2024 +0200
chore: Use ChatMessage to_openai_format, update unit tests, pydocs (#725)
* Use ChatMessage to_openai_format, update unit tests, pydocs
* Minor pydocs fixes, turn off integration tests for nightly runs
* Run only unit tests against haystack-ai main nightly

commit 758e5f372bea65b6ffa149c06a3ab2431d7e7237
Author: jlonge4 <91354480+jlonge4@users.noreply.github.com>
Date: Wed May 8 04:10:55 2024 -0400
feat: Implement keyword retrieval for pgvector integration (#644)
* keyword retriever
* add lang to init for tests
* make suggested edits/test
* fixes to test / lint
* index check query change
* SQLLiteral fix
* table name quotes
* table name quotes
* table name quotes
* test query edit
* remove meta
* move keyword index to init
* move keyword index to init
* move keyword index to init
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* keyword with filters test
* more tests
* rename example
---------
Co-authored-by: anakin87

commit 3c14c52fc38903ec907d28da384b0a36119a3892
Author: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Tue May 7 17:03:50 2024 +0200
Update Nvidia integration to support new endpoints (#701)
* Add support for Nvidia catalog API for generator
* Add support for Nvidia catalog API for embedders
* Add NVIDIA_CATALOG_API_KEY in Nvidia integration workflow
* Enable ruff auto formatting for tests
* Fix linting
* Simplify Secret import and enhance docstring
Co-authored-by: Madeesh Kannan
* Add deprecation warnings for NvcfBackend
* Add truncate parameter for embedders
* Fix linting
* Use enum for truncate mode in embedders
* Change how truncate argument is handled
* Fix truncate conversion
* Update truncate docstring
---------
Co-authored-by: Madeesh Kannan

commit 975e0e528cb4f01a92bc95905fe948674c47aaed
Author: Massimiliano Pippi
Date: Tue May 7 10:57:05 2024 +0200
fix: make unit tests pass (#720)
* make unit tests pass
* linting

commit 8a1242f64e7825238101b646f687ced76177f397
Author: Vladimir Blagojevic
Date: Mon May 6 21:04:54 2024 +0200
Fix langfuse nightly tests (#716)

commit f61db6d3a24cbd9bb5327a8765bf0c8272639754
Author: Massimiliano Pippi
Date: Mon May 6 17:56:21 2024 +0200
change the pydoc renderer class (#718)

commit e5667e78b222135dbaddb77f51d6c096650cdd38
Author: Massimiliano Pippi
Date: Mon May 6 17:54:08 2024 +0200
pass the haystack docs version when generating docs (#719)

commit d30c0eafef3dd399e37a8e4168c90a418e6e5716
Author: Stefano Fiorucci
Date: Mon May 6 14:58:23 2024 +0200
FastembedTextEmbedder - remove batch_size (#688)

commit da46c9c5069ad78c75456dfa2f8684e7ba37d55e
Author: Massimiliano Pippi
Date: Mon May 6 09:58:12 2024 +0200
Update README.md

commit 9659b1305f8750ff1c8a675d054f3a7c360fd0af
Author: Dmitry
Date: Mon May 6 10:43:57 2024 +0300
Type hints in pgvector document store updated for 3.8 compability (#704)

commit 04fb950f182fb9d69af3bf42fee26d32206a9a27
Author: Massimiliano Pippi
Date: Fri May 3 19:01:36 2024 +0200
fix: add multi-line variable to step output in the right way (#714)
* try
* try
* remove testing code

commit 8b916a33ffc25e549ae52728515daf30fa1da5bb
Author: Massimiliano Pippi
Date: Fri May 3 17:53:30 2024 +0200
Follow up: update Cohere integration to use Cohere SDK v5 (#711)
* add support for python client v5
* linting

commit 1d0a5568178340bf4abbdaeecd7b690ebfc7ddec
Author: Massimiliano Pippi
Date: Fri May 3 17:46:25 2024 +0200
chore: sync integrations docs with all the available Haystack versions (#713)
* sync docs with all the versions
* automatically fetch versions
* Update CI_readme_sync.yml
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
---------
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

commit 5c7e1b74e58ddb7943999a8f0ca24cfb0d834884
Author: Chris Knight
Date: Fri May 3 10:53:45 2024 +0100
fix: Weaviate schema class name conversion which preserves PascalCase (#707)
* fix: Weaviate schema class name conversion which preserves PascalCase
* adding a test case for schema name conversion
* linting
* formatting
---------
Co-authored-by: David S. Batista

commit ca87b0edcf607daf80158c13ec182a4f578b6667
Author: Jon
Date: Thu May 2 09:46:22 2024 -0700
Feature/bump chromadb dep to 0.5.0 (#700)
* [DEP]: Bump chromadb version
Remove specific version pin
* Fix: chromadb tests failing due to sort order variance
Fix: linting error
* Add OllamaEmbeddingFunction to function registry
Linting

commit 48521156d9fa1af1afb332395ee430d58fcb9428
Author: Vladimir Blagojevic
Date: Thu May 2 18:08:19 2024 +0200
feat: Langfuse integration (#686)
* Add langfuse integration
* Trace pipeline run
* Integration admin additions
* Pydoc config
* Capture trace url in tracer component
* Add integration test, update example
* Linting
* Add haystack-pydoc-tools dep
* Add comprehensive README
* Handle both ChatMessage and str payloads
* Renaming
* Versioning scheme
* Pydocs, add public trace flag
* Add hatch-vcs dep
* Use OPENAI_API_KEY secret
* update docstrings
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* lint fixups
* Improve test, previous version always returned 200
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Update integrations/langfuse/README.md
Co-authored-by: Daria Fokina
* Add details about Langfuse keys
* Pylint
---------
Co-authored-by: Massimiliano Pippi
Co-authored-by: Daria Fokina

commit 26a16b69398e3dc9894f2a853ba817c728805062
Author: Julian Risch
Date: Thu May 2 13:06:36 2024 +0200
docs: Add Ranker type to Cohere and Jina in inventory (#708)

commit 667162fe6a08dc052063919f52c40479cb574226
Author: Alex W
Date: Thu May 2 04:50:17 2024 -0400
[Cohere] feat: Update integration to use Cohere SDK v5 (#702)

commit 1790fc5718544125144d2d4475f18c24590490ab
Author: Stefano Fiorucci
Date: Fri Apr 26 09:25:30 2024 +0200
qdrant - improve docstrings for retrievers (#687)

commit 9ed323080bcc82ff70e9a0f56a6c2561fec42c2e
Author: Florian Rumiel
Date: Wed Apr 24 13:11:52 2024 -0400
Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685)

---
.github/labeler.yml | 10 +-
.github/pull_request_template.md | 23 +
.github/workflows/CI_project.yml | 2 +-
.github/workflows/CI_pypi_release.yml | 22 +
.github/workflows/CI_readme_sync.yml | 32 +-
.github/workflows/amazon_bedrock.yml | 21 +-
.github/workflows/amazon_sagemaker.yml | 16 +-
.github/workflows/anthropic.yml | 13 +-
.github/workflows/astra.yml | 78 ++-
.github/workflows/chroma.yml | 74 ++-
.github/workflows/cohere.yml | 16 +-
.github/workflows/deepeval.yml | 14 +-
.github/workflows/elasticsearch.yml | 16 +-
.github/workflows/fastembed.yml | 20 +-
.github/workflows/google_ai.yml | 14 +-
.github/workflows/google_vertex.yml | 14 +-
.github/workflows/gradient.yml | 80 ---
.github/workflows/instructor_embedders.yml | 14 +-
.github/workflows/jina.yml | 14 +-
.github/workflows/langfuse.yml | 77 +++
.github/workflows/llama_cpp.yml | 74 ++-
.github/workflows/mistral.yml | 16 +-
.github/workflows/mongodb_atlas.yml | 16 +-
.github/workflows/nvidia.yml | 11 +-
.github/workflows/ollama.yml | 24 +-
.github/workflows/opensearch.yml | 20 +-
.github/workflows/optimum.yml | 14 +-
.github/workflows/pgvector.yml | 24 +-
.github/workflows/pinecone.yml | 22 +-
.github/workflows/qdrant.yml | 14 +-
.github/workflows/ragas.yml | 16 +-
.github/workflows/unstructured.yml | 22 +-
.github/workflows/weaviate.yml | 14 +-
.gitignore | 9 +
README.md | 42 +-
cliff.toml | 84 +++
integrations/amazon_bedrock/CHANGELOG.md | 106 ++++
integrations/amazon_bedrock/pydoc/config.yml | 2 +-
integrations/amazon_bedrock/pyproject.toml | 71 +--
.../amazon_bedrock/document_embedder.py | 14 +-
.../embedders/amazon_bedrock/text_embedder.py | 14 +-
.../generators/amazon_bedrock/adapters.py | 187 ++++--
.../amazon_bedrock/chat/adapters.py | 111 ++--
.../amazon_bedrock/chat/chat_generator.py | 111 ++--
.../generators/amazon_bedrock/generator.py | 128 ++--
.../generators/amazon_bedrock/handlers.py | 33 -
.../tests/test_chat_generator.py | 182 +++++-
.../amazon_bedrock/tests/test_generator.py | 517 ++++++++++++----
integrations/amazon_sagemaker/CHANGELOG.md | 27 +
.../amazon_sagemaker/pydoc/config.yml | 2 +-
integrations/amazon_sagemaker/pyproject.toml | 67 +-
.../amazon_sagemaker/tests/test_sagemaker.py | 17 +-
integrations/anthropic/CHANGELOG.md | 28 +
.../example/documentation_rag_with_claude.py | 9 +-
integrations/anthropic/pydoc/config.yml | 4 +-
integrations/anthropic/pyproject.toml | 75 +--
.../anthropic/chat/chat_generator.py | 49 +-
.../generators/anthropic/generator.py | 2 +-
.../anthropic/tests/test_chat_generator.py | 66 +-
integrations/astra/CHANGELOG.md | 91 +++
integrations/astra/README.md | 52 +-
integrations/astra/pydoc/config.yml | 2 +-
integrations/astra/pyproject.toml | 90 +--
.../components/retrievers/astra/retriever.py | 36 +-
.../document_stores/astra/astra_client.py | 45 +-
.../document_stores/astra/document_store.py | 25 +-
.../document_stores/astra/filters.py | 4 +-
.../astra/tests/test_document_store.py | 48 +-
integrations/astra/tests/test_retriever.py | 58 ++
integrations/chroma/CHANGELOG.md | 100 +++
integrations/chroma/pydoc/config.yml | 2 +-
integrations/chroma/pyproject.toml | 70 +--
.../components/retrievers/chroma/retriever.py | 41 +-
.../document_stores/chroma/document_store.py | 105 +++-
.../document_stores/chroma/utils.py | 2 + .../chroma/tests/test_document_store.py | 133 +++- integrations/chroma/tests/test_retriever.py | 47 ++ integrations/cohere/CHANGELOG.md | 108 ++++ .../cohere/examples/cohere_embedding.py | 28 + .../cohere/examples/cohere_generation.py | 53 ++ ...nker_in_a_pipeline.py => cohere_ranker.py} | 7 +- integrations/cohere/pydoc/config.yml | 2 +- integrations/cohere/pyproject.toml | 21 +- .../embedders/cohere/document_embedder.py | 14 +- .../embedders/cohere/text_embedder.py | 14 +- .../components/embedders/cohere/utils.py | 51 +- .../generators/cohere/chat/chat_generator.py | 138 +++-- .../components/generators/cohere/generator.py | 155 +---- .../components/rankers/cohere/ranker.py | 4 +- .../tests/test_cohere_chat_generator.py | 152 ++--- ...generators.py => test_cohere_generator.py} | 39 +- .../cohere/tests/test_cohere_ranker.py | 2 +- .../cohere/tests/test_document_embedder.py | 8 +- .../cohere/tests/test_text_embedder.py | 8 +- integrations/deepeval/pydoc/config.yml | 2 +- integrations/deepeval/pyproject.toml | 14 +- integrations/elasticsearch/CHANGELOG.md | 90 +++ integrations/elasticsearch/pydoc/config.yml | 2 +- integrations/elasticsearch/pyproject.toml | 77 +-- .../elasticsearch/bm25_retriever.py | 19 +- .../elasticsearch/embedding_retriever.py | 19 +- .../elasticsearch/document_store.py | 83 ++- .../tests/test_bm25_retriever.py | 37 ++ .../tests/test_document_store.py | 45 +- .../tests/test_embedding_retriever.py | 34 + integrations/fastembed/CHANGELOG.md | 63 ++ integrations/fastembed/pydoc/config.yml | 2 +- integrations/fastembed/pyproject.toml | 72 +-- .../embedding_backend/fastembed_backend.py | 18 +- .../fastembed/fastembed_document_embedder.py | 21 +- .../fastembed_sparse_document_embedder.py | 34 +- .../fastembed_sparse_text_embedder.py | 32 +- .../fastembed/fastembed_text_embedder.py | 19 +- .../fastembed/tests/test_fastembed_backend.py | 4 +- .../tests/test_fastembed_document_embedder.py | 12 +- ...test_fastembed_sparse_document_embedder.py | 14 +- .../test_fastembed_sparse_text_embedder.py | 15 +- .../tests/test_fastembed_text_embedder.py | 5 +- integrations/google_ai/CHANGELOG.md | 43 ++ integrations/google_ai/pydoc/config.yml | 2 +- integrations/google_ai/pyproject.toml | 68 +- .../generators/google_ai/chat/gemini.py | 55 +- .../components/generators/google_ai/gemini.py | 14 +- .../tests/generators/chat/test_chat_gemini.py | 93 ++- .../google_ai/tests/generators/test_gemini.py | 16 +- integrations/google_vertex/pydoc/config.yml | 2 +- integrations/google_vertex/pyproject.toml | 69 +-- integrations/gradient/LICENSE.txt | 201 ------ integrations/gradient/README.md | 22 - .../components/embedders/gradient/__init__.py | 7 - .../gradient/gradient_document_embedder.py | 174 ------ .../gradient/gradient_text_embedder.py | 113 ---- .../components/generators/gradient/base.py | 144 ----- integrations/gradient/tests/__init__.py | 3 - .../tests/test_gradient_document_embedder.py | 162 ----- .../tests/test_gradient_rag_pipelines.py | 90 --- .../tests/test_gradient_text_embedder.py | 124 ---- .../instructor_embedders/pydoc/config.yml | 2 +- .../instructor_embedders/pyproject.toml | 21 +- integrations/jina/pydoc/config.yml | 4 +- integrations/jina/pyproject.toml | 14 +- integrations/langfuse/CHANGELOG.md | 19 + integrations/langfuse/LICENSE.txt | 73 +++ integrations/langfuse/README.md | 117 ++++ integrations/langfuse/example/basic_rag.py | 65 ++ integrations/langfuse/example/chat.py | 27 + .../langfuse/example/requirements.txt | 3 + .../{gradient => 
langfuse}/pydoc/config.yml | 17 +- .../{gradient => langfuse}/pyproject.toml | 130 ++-- .../connectors/langfuse}/__init__.py | 4 +- .../connectors/langfuse/langfuse_connector.py | 116 ++++ .../tracing/langfuse/__init__.py | 6 + .../tracing/langfuse/tracer.py | 174 ++++++ integrations/langfuse/tests/__init__.py | 3 + .../langfuse/tests/test_langfuse_span.py | 65 ++ integrations/langfuse/tests/test_tracer.py | 114 ++++ integrations/langfuse/tests/test_tracing.py | 55 ++ integrations/llama_cpp/CHANGELOG.md | 50 ++ integrations/llama_cpp/pydoc/config.yml | 2 +- integrations/llama_cpp/pyproject.toml | 76 +-- .../generators/llama_cpp/__init__.py | 3 +- .../llama_cpp/chat/chat_generator.py | 139 +++++ .../llama_cpp/tests/test_chat_generator.py | 498 +++++++++++++++ .../examples/streaming_chat_with_rag.py | 6 +- integrations/mistral/pydoc/config.yml | 2 +- integrations/mistral/pyproject.toml | 14 +- integrations/mongodb_atlas/CHANGELOG.md | 64 ++ integrations/mongodb_atlas/pydoc/config.yml | 2 +- integrations/mongodb_atlas/pyproject.toml | 73 +-- .../mongodb_atlas/embedding_retriever.py | 20 +- .../mongodb_atlas/document_store.py | 32 +- .../tests/test_document_store.py | 13 +- .../mongodb_atlas/tests/test_retriever.py | 77 +++ integrations/nvidia/CHANGELOG.md | 42 ++ integrations/nvidia/pydoc/config.yml | 4 +- integrations/nvidia/pyproject.toml | 9 +- .../components/embedders/nvidia/__init__.py | 6 +- .../embedders/nvidia/_nim_backend.py | 46 -- .../embedders/nvidia/_nvcf_backend.py | 109 ---- .../components/embedders/nvidia/backend.py | 29 - .../embedders/nvidia/document_embedder.py | 81 ++- .../embedders/nvidia/text_embedder.py | 81 ++- .../components/embedders/nvidia/truncate.py | 32 + .../generators/nvidia/_nim_backend.py | 69 --- .../generators/nvidia/_nvcf_backend.py | 117 ---- .../components/generators/nvidia/_schema.py | 69 --- .../components/generators/nvidia/backend.py | 29 - .../components/generators/nvidia/generator.py | 87 ++- .../utils/nvidia/__init__.py | 5 +- .../utils/nvidia/client.py | 82 --- .../utils/nvidia/nim_backend.py | 131 ++++ .../utils/nvidia/utils.py | 47 ++ integrations/nvidia/tests/__init__.py | 3 + integrations/nvidia/tests/conftest.py | 44 ++ integrations/nvidia/tests/test_base_url.py | 64 ++ .../nvidia/tests/test_document_embedder.py | 249 ++++---- integrations/nvidia/tests/test_generator.py | 170 ++--- .../nvidia/tests/test_text_embedder.py | 152 +++-- integrations/ollama/CHANGELOG.md | 50 ++ integrations/ollama/pydoc/config.yml | 2 +- integrations/ollama/pyproject.toml | 69 +-- .../generators/ollama/chat/chat_generator.py | 61 +- .../ollama/tests/test_chat_generator.py | 28 +- integrations/opensearch/CHANGELOG.md | 105 ++++ integrations/opensearch/pydoc/config.yml | 2 +- integrations/opensearch/pyproject.toml | 79 +-- .../retrievers/opensearch/bm25_retriever.py | 157 ++++- .../opensearch/embedding_retriever.py | 149 ++++- .../document_stores/opensearch/auth.py | 154 +++++ .../opensearch/document_store.py | 327 +++++++--- integrations/opensearch/tests/test_auth.py | 113 ++++ .../opensearch/tests/test_bm25_retriever.py | 73 ++- .../opensearch/tests/test_document_store.py | 583 +++++++++++++++++- .../tests/test_embedding_retriever.py | 89 ++- integrations/optimum/CHANGELOG.md | 37 ++ integrations/optimum/pydoc/config.yml | 2 +- integrations/optimum/pyproject.toml | 16 +- .../components/embedders/optimum/_backend.py | 2 +- integrations/pgvector/CHANGELOG.md | 51 ++ .../{example.py => embedding_retrieval.py} | 0 .../pgvector/examples/hybrid_retrieval.py | 69 +++ 
integrations/pgvector/pydoc/config.yml | 3 +- integrations/pgvector/pyproject.toml | 72 +-- .../retrievers/pgvector/__init__.py | 3 +- .../pgvector/embedding_retriever.py | 23 +- .../retrievers/pgvector/keyword_retriever.py | 137 ++++ .../pgvector/document_store.py | 175 +++++- .../document_stores/pgvector/filters.py | 28 +- integrations/pgvector/tests/conftest.py | 5 +- .../pgvector/tests/test_document_store.py | 24 +- .../pgvector/tests/test_keyword_retrieval.py | 50 ++ integrations/pgvector/tests/test_retriever.py | 116 ---- .../pgvector/tests/test_retrievers.py | 318 ++++++++++ integrations/pinecone/CHANGELOG.md | 54 ++ integrations/pinecone/examples/example.py | 2 +- integrations/pinecone/pydoc/config.yml | 2 +- integrations/pinecone/pyproject.toml | 18 +- .../pinecone/embedding_retriever.py | 34 +- .../pinecone/document_store.py | 101 ++- integrations/pinecone/tests/conftest.py | 22 +- .../pinecone/tests/test_document_store.py | 127 +++- ...triever.py => test_embedding_retriever.py} | 78 ++- integrations/pinecone/tests/test_filters.py | 4 - integrations/qdrant/CHANGELOG.md | 145 +++++ integrations/qdrant/pydoc/config.yml | 2 +- integrations/qdrant/pyproject.toml | 19 +- .../components/retrievers/qdrant/retriever.py | 134 +++- .../document_stores/qdrant/converters.py | 5 +- .../document_stores/qdrant/document_store.py | 443 ++++++++++--- .../document_stores/qdrant/filters.py | 213 +++++-- .../qdrant/tests/test_dict_converters.py | 18 +- .../qdrant/tests/test_document_store.py | 29 +- integrations/qdrant/tests/test_filters.py | 122 ++++ integrations/qdrant/tests/test_retriever.py | 161 ++++- integrations/ragas/CHANGELOG.md | 47 ++ integrations/ragas/pydoc/config.yml | 2 +- integrations/ragas/pyproject.toml | 16 +- .../components/evaluators/ragas/metrics.py | 10 - integrations/ragas/tests/test_evaluator.py | 12 +- integrations/unstructured/CHANGELOG.md | 67 ++ integrations/unstructured/pydoc/config.yml | 2 +- integrations/unstructured/pyproject.toml | 83 +-- integrations/weaviate/CHANGELOG.md | 67 ++ integrations/weaviate/pydoc/config.yml | 2 +- integrations/weaviate/pyproject.toml | 26 +- .../retrievers/weaviate/__init__.py | 4 + .../retrievers/weaviate/bm25_retriever.py | 27 +- .../weaviate/embedding_retriever.py | 28 +- .../document_stores/weaviate/_filters.py | 4 + .../document_stores/weaviate/auth.py | 4 + .../weaviate/document_store.py | 138 +++-- integrations/weaviate/tests/conftest.py | 4 + integrations/weaviate/tests/test_auth.py | 4 + .../weaviate/tests/test_bm25_retriever.py | 54 ++ .../weaviate/tests/test_document_store.py | 70 ++- .../tests/test_embedding_retriever.py | 57 ++ integrations/weaviate/tests/test_filters.py | 4 + show_unreleased.sh | 9 + 278 files changed, 11160 insertions(+), 5047 deletions(-) create mode 100644 .github/pull_request_template.md delete mode 100644 .github/workflows/gradient.yml create mode 100644 .github/workflows/langfuse.yml create mode 100644 cliff.toml create mode 100644 integrations/amazon_bedrock/CHANGELOG.md mode change 100644 => 100755 integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py mode change 100644 => 100755 integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py create mode 100644 integrations/amazon_sagemaker/CHANGELOG.md create mode 100644 integrations/anthropic/CHANGELOG.md create mode 100644 integrations/astra/CHANGELOG.md create mode 100644 integrations/chroma/CHANGELOG.md create mode 100644 
integrations/cohere/CHANGELOG.md create mode 100644 integrations/cohere/examples/cohere_embedding.py create mode 100644 integrations/cohere/examples/cohere_generation.py rename integrations/cohere/examples/{cohere_ranker_in_a_pipeline.py => cohere_ranker.py} (81%) rename integrations/cohere/tests/{test_cohere_generators.py => test_cohere_generator.py} (86%) create mode 100644 integrations/elasticsearch/CHANGELOG.md create mode 100644 integrations/fastembed/CHANGELOG.md create mode 100644 integrations/google_ai/CHANGELOG.md delete mode 100644 integrations/gradient/LICENSE.txt delete mode 100644 integrations/gradient/README.md delete mode 100644 integrations/gradient/src/haystack_integrations/components/embedders/gradient/__init__.py delete mode 100644 integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py delete mode 100644 integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py delete mode 100644 integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py delete mode 100644 integrations/gradient/tests/__init__.py delete mode 100644 integrations/gradient/tests/test_gradient_document_embedder.py delete mode 100644 integrations/gradient/tests/test_gradient_rag_pipelines.py delete mode 100644 integrations/gradient/tests/test_gradient_text_embedder.py create mode 100644 integrations/langfuse/CHANGELOG.md create mode 100644 integrations/langfuse/LICENSE.txt create mode 100644 integrations/langfuse/README.md create mode 100644 integrations/langfuse/example/basic_rag.py create mode 100644 integrations/langfuse/example/chat.py create mode 100644 integrations/langfuse/example/requirements.txt rename integrations/{gradient => langfuse}/pydoc/config.yml (55%) rename integrations/{gradient => langfuse}/pyproject.toml (57%) rename integrations/{gradient/src/haystack_integrations/components/generators/gradient => langfuse/src/haystack_integrations/components/connectors/langfuse}/__init__.py (57%) create mode 100644 integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py create mode 100644 integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py create mode 100644 integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py create mode 100644 integrations/langfuse/tests/__init__.py create mode 100644 integrations/langfuse/tests/test_langfuse_span.py create mode 100644 integrations/langfuse/tests/test_tracer.py create mode 100644 integrations/langfuse/tests/test_tracing.py create mode 100644 integrations/llama_cpp/CHANGELOG.md create mode 100644 integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py create mode 100644 integrations/llama_cpp/tests/test_chat_generator.py create mode 100644 integrations/mongodb_atlas/CHANGELOG.md create mode 100644 integrations/nvidia/CHANGELOG.md delete mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py delete mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nvcf_backend.py delete mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/backend.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py delete mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py delete mode 100644 
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nvcf_backend.py delete mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py delete mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/nvidia/backend.py delete mode 100644 integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py create mode 100644 integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py create mode 100644 integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py create mode 100644 integrations/nvidia/tests/conftest.py create mode 100644 integrations/nvidia/tests/test_base_url.py create mode 100644 integrations/ollama/CHANGELOG.md create mode 100644 integrations/opensearch/CHANGELOG.md create mode 100644 integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py create mode 100644 integrations/opensearch/tests/test_auth.py create mode 100644 integrations/optimum/CHANGELOG.md create mode 100644 integrations/pgvector/CHANGELOG.md rename integrations/pgvector/examples/{example.py => embedding_retrieval.py} (100%) create mode 100644 integrations/pgvector/examples/hybrid_retrieval.py create mode 100644 integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py create mode 100644 integrations/pgvector/tests/test_keyword_retrieval.py delete mode 100644 integrations/pgvector/tests/test_retriever.py create mode 100644 integrations/pgvector/tests/test_retrievers.py create mode 100644 integrations/pinecone/CHANGELOG.md rename integrations/pinecone/tests/{test_emebedding_retriever.py => test_embedding_retriever.py} (55%) create mode 100644 integrations/qdrant/CHANGELOG.md create mode 100644 integrations/ragas/CHANGELOG.md create mode 100644 integrations/unstructured/CHANGELOG.md create mode 100644 integrations/weaviate/CHANGELOG.md create mode 100755 show_unreleased.sh diff --git a/.github/labeler.yml b/.github/labeler.yml index cbfe6567e..d8bb71098 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -44,11 +44,6 @@ integration:google-vertex: - any-glob-to-any-file: "integrations/google_vertex/**/*" - any-glob-to-any-file: ".github/workflows/google_vertex.yml" -integration:gradient: - - changed-files: - - any-glob-to-any-file: "integrations/gradient/**/*" - - any-glob-to-any-file: ".github/workflows/gradient.yml" - integration:instructor-embedders: - changed-files: - any-glob-to-any-file: "integrations/instructor_embedders/**/*" @@ -59,6 +54,11 @@ integration:jina: - any-glob-to-any-file: "integrations/jina/**/*" - any-glob-to-any-file: ".github/workflows/jina.yml" +integration:langfuse: + - changed-files: + - any-glob-to-any-file: "integrations/langfuse/**/*" + - any-glob-to-any-file: ".github/workflows/langfuse.yml" + integration:llama_cpp: - changed-files: - any-glob-to-any-file: "integrations/llama_cpp/**/*" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..8b23b0f60 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,23 @@ +### Related Issues + +- fixes #issue-number + +### Proposed Changes: + + + + +### How did you test it? 
+ + + +### Notes for the reviewer + + + +### Checklist + +- I have read the [contributors guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md) and the [code of conduct](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CODE_OF_CONDUCT.md) +- I have updated the related issue with new insights and changes +- I added unit tests and updated the docstrings +- I've used one of the [conventional commit types](https://www.conventionalcommits.org/en/v1.0.0/) for my PR title: `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`. diff --git a/.github/workflows/CI_project.yml b/.github/workflows/CI_project.yml index 8e40b3078..8e48ca832 100644 --- a/.github/workflows/CI_project.yml +++ b/.github/workflows/CI_project.yml @@ -10,7 +10,7 @@ jobs: name: Add new issues to project for triage runs-on: ubuntu-latest steps: - - uses: actions/add-to-project@v1.0.1 + - uses: actions/add-to-project@v1.0.2 with: project-url: https://github.com/orgs/deepset-ai/projects/5 github-token: ${{ secrets.GH_PROJECT_PAT }} diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index 57793c588..60115055b 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -20,6 +20,9 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + token: ${{ secrets.HAYSTACK_BOT_TOKEN }} + fetch-depth: 0 - name: Install Hatch run: pip install hatch @@ -43,3 +46,22 @@ jobs: HATCH_INDEX_USER: __token__ HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }} run: hatch publish -y + + - name: Generate changelog + uses: orhun/git-cliff-action@v3 + env: + OUTPUT: "${{ steps.pathfinder.outputs.project_path }}/CHANGELOG.md" + with: + config: cliff.toml + args: > + --include-path "${{ steps.pathfinder.outputs.project_path }}/*" + --tag-pattern "${{ steps.pathfinder.outputs.project_path }}-v*" + + - name: Commit changelog + uses: EndBug/add-and-commit@v9 + with: + author_name: "HaystackBot" + author_email: "accounts@deepset.ai" + message: "Update the changelog" + add: ${{ steps.pathfinder.outputs.project_path }} + push: origin HEAD:main diff --git a/.github/workflows/CI_readme_sync.yml b/.github/workflows/CI_readme_sync.yml index 27f225e74..c6204a9be 100644 --- a/.github/workflows/CI_readme_sync.yml +++ b/.github/workflows/CI_readme_sync.yml @@ -4,6 +4,7 @@ on: push: tags: - "**-v[0-9].[0-9]+.[0-9]+" + workflow_dispatch: # Activate this workflow manually inputs: tag: @@ -16,8 +17,30 @@ env: TAG: ${{ inputs.tag || github.ref_name }} jobs: + get-versions: + runs-on: ubuntu-latest + outputs: + versions: ${{ steps.version_finder.outputs.versions }} + steps: + - name: Get Haystack Docs versions + id: version_finder + run: | + curl -s https://dash.readme.com/api/v1/version --header 'authorization: Basic ${{ secrets.README_API_KEY }}' > out + VERSIONS=$(jq '[ .[] | select(.version | startswith("2."))| .version ]' out) + { + echo 'versions<> "$GITHUB_OUTPUT" + sync: runs-on: ubuntu-latest + needs: get-versions + strategy: + fail-fast: false + max-parallel: 1 + matrix: + hs-docs-version: ${{ fromJSON(needs.get-versions.outputs.versions) }} steps: - name: Checkout this repo uses: actions/checkout@v4 @@ -39,7 +62,7 @@ jobs: import os project_path = os.environ["TAG"].rsplit("-", maxsplit=1)[0] with open(os.environ['GITHUB_OUTPUT'], 'a') as f: - print(f'project_path={project_path}', file=f) + print(f'project_path={project_path}', file=f) - name: Generate docs working-directory: ${{ 
steps.pathfinder.outputs.project_path }} @@ -48,13 +71,16 @@ jobs: # from Readme.io as we need them to associate the slug # in config files with their id. README_API_KEY: ${{ secrets.README_API_KEY }} + # The same category has a different id on different readme docs versions. + # This is the docs version on readme that we'll use to get the category id. + PYDOC_TOOLS_HAYSTACK_DOC_VERSION: ${{ matrix.hs-docs-version }} run: | hatch run docs mkdir tmp find . -name "_readme_*.md" -exec cp "{}" tmp \; ls tmp - - name: Sync API docs + - name: Sync API docs with Haystack docs version ${{ matrix.hs-docs-version }} uses: readmeio/rdme@v8 with: - rdme: docs ${{ steps.pathfinder.outputs.project_path }}/tmp --key=${{ secrets.README_API_KEY }} --version=2.0 + rdme: docs ${{ steps.pathfinder.outputs.project_path }}/tmp --key=${{ secrets.README_API_KEY }} --version=${{ matrix.hs-docs-version }} diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 214b65c4b..2057d4bdf 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -62,29 +62,32 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs + - name: Run unit tests + run: hatch run cov-retry -m "not integration" + + # Do not authenticate on pull requests from forks - name: AWS authentication + if: github.event.pull_request.head.repo.full_name == github.repository uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 with: aws-region: ${{ env.AWS_REGION }} role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }} - - name: Run tests - id: tests - run: hatch run cov + # Do not run integration tests on pull requests from forks + - name: Run integration tests + if: github.event.pull_request.head.repo.full_name == github.repository + run: hatch run cov-retry -m "integration" - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" + hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/amazon_sagemaker.yml b/.github/workflows/amazon_sagemaker.yml index 2f9106181..ed0a571e6 100644 --- a/.github/workflows/amazon_sagemaker.yml +++ b/.github/workflows/amazon_sagemaker.yml @@ -54,25 +54,21 @@ jobs: - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: 
- ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/anthropic.yml b/.github/workflows/anthropic.yml index db4c4ce18..c4cdeb2d1 100644 --- a/.github/workflows/anthropic.yml +++ b/.github/workflows/anthropic.yml @@ -54,21 +54,18 @@ jobs: run: hatch run lint:all - name: Run tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index 74ac1023f..dcfc00c75 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -7,8 +7,8 @@ on: - cron: "0 0 * * *" pull_request: paths: - - 'integrations/astra/**' - - '.github/workflows/astra.yml' + - "integrations/astra/**" + - ".github/workflows/astra.yml" defaults: run: @@ -31,52 +31,48 @@ jobs: max-parallel: 1 matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10'] + python-version: ["3.9", "3.10"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . - run: git config --system core.longpaths true + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Install Hatch - run: pip install --upgrade hatch + - name: Install Hatch + run: pip install --upgrade hatch - - name: Lint - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run lint:all + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs - - name: Run tests - env: - ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }} - ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} - id: tests - run: hatch run cov + - name: Run tests + env: + ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }} + ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} + run: hatch run cov-retry - - name: Nightly - run unit tests with Haystack main branch - if: github.event_name == 'schedule' - id: nightly-haystack-main - run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" - - name: Send event to Datadog for nightly failures - if: failure() && github.event_name == 'schedule' - uses: ./.github/actions/send_failure - with: - title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 616fecf3b..26b6287bd 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -7,8 +7,8 @@ on: - cron: "0 0 * * *" pull_request: paths: - - 'integrations/chroma/**' - - '.github/workflows/chroma.yml' + - "integrations/chroma/**" + - ".github/workflows/chroma.yml" defaults: run: @@ -30,49 +30,45 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.8', '3.9', '3.10'] + python-version: ["3.8", "3.9", "3.10"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . - run: git config --system core.longpaths true + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Install Hatch - run: pip install --upgrade hatch + - name: Install Hatch + run: pip install --upgrade hatch - - name: Lint - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run lint:all + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs - - name: Run tests - id: tests - run: hatch run cov + - name: Run tests + run: hatch run cov-retry - - name: Nightly - run unit tests with Haystack main branch - if: github.event_name == 'schedule' - id: nightly-haystack-main - run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - - - name: Send event to Datadog for nightly failures - if: failure() && github.event_name == 'schedule' - uses: ./.github/actions/send_failure - with: - title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 81bf2356d..00a8ee2ed 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -55,25 +55,21 @@ jobs: - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/deepeval.yml b/.github/workflows/deepeval.yml index e987f03f0..23de1a3f4 100644 --- a/.github/workflows/deepeval.yml +++ b/.github/workflows/deepeval.yml @@ -58,22 +58,18 @@ 
jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index c0144a202..476e832b5 100644 --- a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -48,29 +48,25 @@ jobs: run: hatch run lint:all - name: Run ElasticSearch container - run: docker-compose up -d + run: docker compose up -d - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/fastembed.yml b/.github/workflows/fastembed.yml index 7a34378ee..e389bf3a4 100644 --- a/.github/workflows/fastembed.yml +++ b/.github/workflows/fastembed.yml @@ -20,9 +20,9 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.9","3.10","3.11"] - - steps: + python-version: ["3.9", "3.10", "3.11"] + + steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -39,25 +39,21 @@ jobs: - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} + Core integrations nightly tests failure: ${{ github.workflow }} api-key: ${{ 
secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/google_ai.yml b/.github/workflows/google_ai.yml index 5563e5633..1b4b2e496 100644 --- a/.github/workflows/google_ai.yml +++ b/.github/workflows/google_ai.yml @@ -58,22 +58,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml index 3677102e4..78ba5694b 100644 --- a/.github/workflows/google_vertex.yml +++ b/.github/workflows/google_vertex.yml @@ -57,22 +57,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/gradient.yml b/.github/workflows/gradient.yml deleted file mode 100644 index 61133be87..000000000 --- a/.github/workflows/gradient.yml +++ /dev/null @@ -1,80 +0,0 @@ -# This workflow comes from https://github.com/ofek/hatch-mypyc -# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml -name: Test / gradient - -on: - schedule: - - cron: "0 0 * * *" - pull_request: - paths: - - 'integrations/gradient/**' - - '.github/workflows/gradient.yml' - -defaults: - run: - working-directory: integrations/gradient - -concurrency: - group: gradient-${{ github.head_ref }} - cancel-in-progress: true - -env: - PYTHONUNBUFFERED: "1" - FORCE_COLOR: "1" - GRADIENT_ACCESS_TOKEN: ${{ secrets.GRADIENT_ACCESS_TOKEN }} - GRADIENT_WORKSPACE_ID: ${{ secrets.GRADIENT_WORKSPACE_ID }} - -jobs: - run: - name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10'] - - steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . 
- run: git config --system core.longpaths true - - - uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install Hatch - run: pip install --upgrade hatch - - - name: Lint - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run lint:all - - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs - - - name: Run tests - id: tests - run: hatch run cov - - - name: Nightly - run unit tests with Haystack main branch - if: github.event_name == 'schedule' - id: nightly-haystack-main - run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - - - name: Send event to Datadog for nightly failures - if: failure() && github.event_name == 'schedule' - uses: ./.github/actions/send_failure - with: - title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file diff --git a/.github/workflows/instructor_embedders.yml b/.github/workflows/instructor_embedders.yml index 70b604eaa..f12f4d696 100644 --- a/.github/workflows/instructor_embedders.yml +++ b/.github/workflows/instructor_embedders.yml @@ -35,22 +35,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/jina.yml b/.github/workflows/jina.yml index 69ef6b294..00af6eb45 100644 --- a/.github/workflows/jina.yml +++ b/.github/workflows/jina.yml @@ -57,22 +57,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/langfuse.yml b/.github/workflows/langfuse.yml new file mode 100644 index 000000000..8a10cf241 --- /dev/null +++ 
b/.github/workflows/langfuse.yml @@ -0,0 +1,77 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / langfuse + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/langfuse/**" + - ".github/workflows/langfuse.yml" + +defaults: + run: + working-directory: integrations/langfuse + +concurrency: + group: langfuse-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + LANGFUSE_SECRET_KEY: ${{ secrets.LANGFUSE_SECRET_KEY }} + LANGFUSE_PUBLIC_KEY: ${{ secrets.LANGFUSE_PUBLIC_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/llama_cpp.yml b/.github/workflows/llama_cpp.yml index 5d906a251..a9480ca96 100644 --- a/.github/workflows/llama_cpp.yml +++ b/.github/workflows/llama_cpp.yml @@ -7,8 +7,8 @@ on: - cron: "0 0 * * *" pull_request: paths: - - 'integrations/llama_cpp/**' - - '.github/workflows/llama_cpp.yml' + - "integrations/llama_cpp/**" + - ".github/workflows/llama_cpp.yml" defaults: run: @@ -30,49 +30,45 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10'] + python-version: ["3.9", "3.10"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . - run: git config --system core.longpaths true + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Install Hatch - run: pip install --upgrade hatch + - name: Install Hatch + run: pip install --upgrade hatch - - name: Lint - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run lint:all + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs - - name: Run tests - id: tests - run: hatch run cov + - name: Run tests + run: hatch run cov-retry - - name: Nightly - run unit tests with Haystack main branch - if: github.event_name == 'schedule' - id: nightly-haystack-main - run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - - - name: Send event to Datadog for nightly failures - if: failure() && github.event_name == 'schedule' - uses: ./.github/actions/send_failure - with: - title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/mistral.yml b/.github/workflows/mistral.yml index 8348f1d8f..e62008906 100644 --- a/.github/workflows/mistral.yml +++ b/.github/workflows/mistral.yml @@ -52,28 +52,24 @@ jobs: - name: Lint if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all - + - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/mongodb_atlas.yml b/.github/workflows/mongodb_atlas.yml index 94a540719..3d1ad5101 100644 --- 
a/.github/workflows/mongodb_atlas.yml +++ b/.github/workflows/mongodb_atlas.yml @@ -31,7 +31,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10', '3.11'] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 @@ -54,22 +54,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/nvidia.yml b/.github/workflows/nvidia.yml index 316e509be..0d39a4d91 100644 --- a/.github/workflows/nvidia.yml +++ b/.github/workflows/nvidia.yml @@ -22,6 +22,7 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }} + NVIDIA_CATALOG_API_KEY: ${{ secrets.NVIDIA_CATALOG_API_KEY }} jobs: run: @@ -54,8 +55,7 @@ jobs: run: hatch run lint:all - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' @@ -63,17 +63,14 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" + hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} + Core integrations nightly tests failure: ${{ github.workflow }} api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 631c155eb..43af485b7 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -32,16 +32,16 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.9","3.10","3.11"] - - steps: + python-version: ["3.9", "3.10", "3.11"] + + steps: - uses: actions/checkout@v4 - name: Install Ollama and pull the required models run: | curl -fsSL https://ollama.com/install.sh | sh ollama serve & - + # Check if the service is up and running with a timeout of 60 seconds timeout=60 while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do @@ -54,7 +54,7 @@ jobs: echo "Timed out waiting for Ollama service to start." 
exit 1 fi - + ollama pull ${{ env.LLM_FOR_TESTS }} ollama pull ${{ env.EMBEDDER_FOR_TESTS }} @@ -75,22 +75,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index ed7967d79..48169a75f 100644 --- a/.github/workflows/opensearch.yml +++ b/.github/workflows/opensearch.yml @@ -20,7 +20,7 @@ env: defaults: run: - working-directory: integrations/opensearch + working-directory: integrations/opensearch jobs: run: @@ -48,29 +48,25 @@ jobs: run: hatch run lint:all - name: Run opensearch container - run: docker-compose up -d + run: docker compose up -d - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/optimum.yml b/.github/workflows/optimum.yml index c6ae9a0ee..c33baa7f8 100644 --- a/.github/workflows/optimum.yml +++ b/.github/workflows/optimum.yml @@ -57,22 +57,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git 
a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index 288da8e9d..0fe20e037 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -20,7 +20,7 @@ env: defaults: run: - working-directory: integrations/pgvector + working-directory: integrations/pgvector jobs: run: @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.9","3.10","3.11"] + python-version: ["3.9", "3.10", "3.11"] services: pgvector: image: ankane/pgvector:latest @@ -40,8 +40,8 @@ jobs: POSTGRES_DB: postgres ports: - 5432:5432 - - steps: + + steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -58,25 +58,21 @@ jobs: - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs + run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/pinecone.yml b/.github/workflows/pinecone.yml index 85f36ae5d..9e143005b 100644 --- a/.github/workflows/pinecone.yml +++ b/.github/workflows/pinecone.yml @@ -33,6 +33,12 @@ jobs: # Pinecone tests are time expensive, so the matrix is limited to Python 3.9 and 3.10 os: [ubuntu-latest] python-version: ["3.9", "3.10"] + # the INDEX_NAME is used in test_serverless_index_creation_from_scratch + include: + - python-version: "3.9" + INDEX_NAME: "index-39" + - python-version: "3.10" + INDEX_NAME: "index-310" steps: - uses: actions/checkout@v4 @@ -55,22 +61,20 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + env: + INDEX_NAME: ${{ matrix.INDEX_NAME }} + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/qdrant.yml b/.github/workflows/qdrant.yml index 5995911fb..116225b2d 100644 --- a/.github/workflows/qdrant.yml +++ b/.github/workflows/qdrant.yml @@ -57,22 +57,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack 
main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/ragas.yml b/.github/workflows/ragas.yml index af67484ad..c4757e704 100644 --- a/.github/workflows/ragas.yml +++ b/.github/workflows/ragas.yml @@ -21,7 +21,7 @@ concurrency: env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} jobs: run: @@ -58,22 +58,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/unstructured.yml b/.github/workflows/unstructured.yml index a5943c5b5..e4b640275 100644 --- a/.github/workflows/unstructured.yml +++ b/.github/workflows/unstructured.yml @@ -8,11 +8,11 @@ on: pull_request: paths: - "integrations/unstructured/**" - - ".github/workflows/unstructured.yml" + - ".github/workflows/unstructured.yml" defaults: run: - working-directory: integrations/unstructured + working-directory: integrations/unstructured concurrency: group: unstructured-${{ github.head_ref }} @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] steps: - name: Free up disk space @@ -48,7 +48,7 @@ jobs: --health-interval 10s \ --health-timeout 1s \ --health-retries 10 \ - quay.io/unstructured-io/unstructured-api:latest + quay.io/unstructured-io/unstructured-api:latest - uses: actions/checkout@v4 @@ -69,22 +69,18 @@ jobs: run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 
'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index 2588d4113..5e29eafe7 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -47,29 +47,25 @@ jobs: run: hatch run lint:all - name: Run Weaviate container - run: docker-compose up -d + run: docker compose up -d - name: Generate docs if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs - name: Run tests - id: tests - run: hatch run cov + run: hatch run cov-retry - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' - id: nightly-haystack-main run: | hatch run pip install git+https://github.com/deepset-ai/haystack.git - hatch run test -m "not integration" - + hatch run cov-retry -m "not integration" + - name: Send event to Datadog for nightly failures if: failure() && github.event_name == 'schedule' uses: ./.github/actions/send_failure with: title: | - core-integrations failure: - ${{ (steps.tests.conclusion == 'nightly-haystack-main') && 'nightly-haystack-main' || 'tests' }} - - ${{ github.workflow }} + Core integrations nightly tests failure: ${{ github.workflow }} api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.gitignore b/.gitignore index c3a7cf863..890c704fe 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,12 @@ dmypy.json # Docs generation artifacts _readme_*.md .idea + +# macOS +.DS_Store + +# http cache (requests-cache) +**/http_cache.sqlite + +# ruff +.ruff_cache \ No newline at end of file diff --git a/README.md b/README.md index 734672371..010ca1763 100644 --- a/README.md +++ b/README.md @@ -23,21 +23,21 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta ## Inventory | Package | Type | PyPi Package | Status | -| -------------------------------------------------------------------------------------------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [amazon-bedrock-haystack](integrations/amazon-bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | 
+|----------------------------------------------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [amazon-bedrock-haystack](integrations/amazon_bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | | [amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | | [anthropic-haystack](integrations/anthropic/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/anthropic-haystack.svg)](https://pypi.org/project/anthropic-haystack) | [![Test / anthropic](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml) | | [astra-haystack](integrations/astra/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/astra-haystack.svg)](https://pypi.org/project/astra-haystack) | [![Test / astra](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml) | | [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | -| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / 
deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | | [fastembed-haystack](integrations/fastembed/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | -| [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | -| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [jina-haystack](integrations/jina/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [langfuse-haystack](integrations/langfuse/) | Tracer | [![PyPI - 
Version](https://img.shields.io/pypi/v/langfuse-haystack.svg?color=orange)](https://pypi.org/project/langfuse-haystack) | [![Test / langfuse](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml) | | [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | | [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | | [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | @@ -50,5 +50,35 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [ragas-haystack](integrations/ragas/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/ragas-haystack.svg)](https://pypi.org/project/ragas-haystack) | [![Test / ragas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | -| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | Staged | +| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | | [weaviate-haystack](integrations/weaviate/) 
| Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | + +## Releasing + +> [!NOTE] +> Only maintainers can release new versions of integrations. +> If you're a community contributor and want to release a new version of an integration, +> reach out to a maintainer. + +To release a new version of an integration to PyPI tag the commit with the right version number and push the tag to +GitHub. The GitHub Actions workflow will take care of the rest. + +1. Tag the commit with the right version number + + The tag needs to have the following format: + + ``` + git tag integrations/- + ``` + + For example, if we want to release version 1.0.99 of the google-vertex-haystack integration we'd have to push the tag: + + ``` + git tag integrations/google_vertex-v1.0.99 + ``` +2. Push the tag to GitHub + + ``` + git push --tags origin + ``` +3. Wait for the CI to do its magic diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 000000000..29228543e --- /dev/null +++ b/cliff.toml @@ -0,0 +1,84 @@ +# git-cliff ~ default configuration file +# https://git-cliff.org/docs/configuration +# +# Lines starting with "#" are comments. +# Configuration options are organized into tables and keys. +# See documentation for more information on available options. + +[changelog] +# changelog header +header = """ +# Changelog\n +""" +# template for the changelog body +# https://keats.github.io/tera/docs/#introduction +body = """ +{% if version %}\ + ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} +{% else %}\ + ## [unreleased] +{% endif %}\ +{% for group, commits in commits | group_by(attribute="group") %} + ### {{ group | striptags | trim | upper_first }} + {% for commit in commits %} + - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ + {% if commit.breaking %}[**breaking**] {% endif %}\ + {{ commit.message | upper_first }}\ + {% endfor %} +{% endfor %}\n +""" +# template for the changelog footer +footer = """ + +""" +# remove the leading and trailing s +trim = true +# postprocessors +postprocessors = [ + # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL +] + +[git] +# parse the commits based on https://www.conventionalcommits.org +conventional_commits = true +# filter out the commits that are not conventional +filter_unconventional = false +# process each line of a commit as an individual commit +split_commits = false +# regex for preprocessing the commit messages +commit_preprocessors = [ + # Replace issue numbers + #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, + # Check spelling of the commit with https://github.com/crate-ci/typos + # If the spelling is incorrect, it will be automatically fixed. 
+ #{ pattern = '.*', replace_command = 'typos --write-changes -' }, +] +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^feat", group = "🚀 Features" }, + { message = "^fix", group = "🐛 Bug Fixes" }, + { message = "^doc", group = "📚 Documentation" }, + { message = "^perf", group = "⚡ Performance" }, + { message = "^refactor", group = "🚜 Refactor" }, + { message = "^style", group = "🎨 Styling" }, + { message = "^test", group = "🧪 Testing" }, + { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" }, + { body = ".*security", group = "🛡️ Security" }, + { message = "^revert", group = "◀️ Revert" }, +] +# protect breaking changes from being skipped due to matching a skipping commit_parser +protect_breaking_commits = false +# filter out the commits that are not matched by commit parsers +filter_commits = false +# regex for matching git tags +# tag_pattern = "v[0-9].*" +# regex for skipping tags +# skip_tags = "" +# regex for ignoring tags +# ignore_tags = "" +# sort the tags topologically +topo_order = false +# sort the commits inside sections by oldest/newest order +sort_commits = "oldest" +# limit the number of commits included in the changelog. +# limit_commits = 42 diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md new file mode 100644 index 000000000..d347c08d9 --- /dev/null +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -0,0 +1,106 @@ +# Changelog + +## [integrations/amazon_bedrock-v1.0.0] - 2024-08-12 + +### 🚜 Refactor + +- Change meta data fields (#911) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +## [integrations/amazon_bedrock-v0.10.0] - 2024-08-12 + +### 🐛 Bug Fixes + +- Support streaming_callback param in amazon bedrock generators (#927) + +### Docs + +- Update AmazonBedrockChatGenerator docstrings (#949) +- Update AmazonBedrockGenerator docstrings (#956) + +## [integrations/amazon_bedrock-v0.9.3] - 2024-07-17 + +### 🚀 Features + +- Use non-gated tokenizer as fallback for mistral in AmazonBedrockChatGenerator (#843) +- Made truncation optional for BedrockGenerator (#833) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/amazon_bedrock-v0.9.0] - 2024-06-14 + +### 🚀 Features + +- Support Claude v3, Llama3 and Command R models on Amazon Bedrock (#809) + +### 🧪 Testing + +- Amazon Bedrock - skip integration tests from forks (#801) + +## [integrations/amazon_bedrock-v0.8.0] - 2024-05-23 + +### 🐛 Bug Fixes + +- Max_tokens typo in Mistral Chat (#740) + +## [integrations/amazon_bedrock-v0.7.1] - 2024-04-24 + +## [integrations/amazon_bedrock-v0.7.0] - 2024-04-16 + +### 🚀 Features + +- Add Mistral Amazon Bedrock support (#632) + +### 📚 Documentation + +- Disable-class-def (#556) + +## [integrations/amazon_bedrock-v0.6.0] - 2024-03-11 + +### 🚀 Features + +- AmazonBedrockChatGenerator - migrate Anthropic chat models to use messaging API (#545) + +### 📚 Documentation + +- Small consistency improvements (#536) +- Review integrations bedrock (#550) + +## [integrations/amazon_bedrock-v0.5.1] - 2024-02-22 + +### 🚀 Features + +- Add Amazon Bedrock chat model support (#333) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +### ⚙️ Miscellaneous Tasks + +- Update Amazon Bedrock integration to use new generic callable (de)serializers for their callback handlers (#452) +- Use 
`serialize_callable` instead of `serialize_callback_handler` in Bedrock (#459) + +## [integrations/amazon_bedrock-v0.3.0] - 2024-01-30 + +### ⚙️ Miscellaneous Tasks + +- [**breaking**] Rename `model_name` to `model` in `AmazonBedrockGenerator` (#220) +- Amazon Bedrock subproject refactoring (#293) +- Adjust amazon bedrock helper classes names (#297) + +## [integrations/amazon_bedrock-v0.1.0] - 2024-01-03 + + diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml index c719c7cfd..6cb05d6f3 100644 --- a/integrations/amazon_bedrock/pydoc/config.yml +++ b/integrations/amazon_bedrock/pydoc/config.yml @@ -20,7 +20,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Amazon Bedrock integration for Haystack category_slug: integrations-api title: Amazon Bedrock diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index bbdc61484..f4a410dbd 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -25,11 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "boto3>=1.28.57", - "transformers" -] +dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme" @@ -51,50 +45,31 @@ git_describe_command = 'git describe --tags --match="integrations/amazon_bedrock dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] -all = [ - "style", - "typing", -] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -134,11 +109,19 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible 
passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", # Ignore unused params - "ARG001", "ARG002", "ARG005", + "ARG001", + "ARG002", + "ARG005", ] unfixable = [ # Don't touch unused imports @@ -162,12 +145,8 @@ parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py old mode 100644 new mode 100755 index a5621cbd2..1b8fde124 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py @@ -16,7 +16,12 @@ logger = logging.getLogger(__name__) -SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] +SUPPORTED_EMBEDDING_MODELS = [ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", + "amazon.titan-embed-text-v2:0", +] @component @@ -51,7 +56,12 @@ class AmazonBedrockDocumentEmbedder: def __init__( self, - model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"], + model: Literal[ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", + "amazon.titan-embed-text-v2:0", + ], aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008 aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 "AWS_SECRET_ACCESS_KEY", strict=False diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py old mode 100644 new mode 100755 index 91a9e3b72..0cceda92f --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py @@ -14,7 +14,12 @@ logger = logging.getLogger(__name__) -SUPPORTED_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] +SUPPORTED_EMBEDDING_MODELS = [ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", + "amazon.titan-embed-text-v2:0", +] @component @@ -44,7 +49,12 @@ class AmazonBedrockTextEmbedder: def __init__( self, - model: Literal["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"], + model: Literal[ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", + "amazon.titan-embed-text-v2:0", + ], aws_access_key_id: Optional[Secret] = Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False), # noqa: B008 aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 "AWS_SECRET_ACCESS_KEY", strict=False diff --git 
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index f5bd4aa07..8b5c2b530 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -1,8 +1,8 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional -from .handlers import TokenStreamingHandler +from haystack.dataclasses import StreamingChunk class BedrockModelAdapter(ABC): @@ -39,22 +39,24 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]: responses = [completion.lstrip() for completion in completions] return responses - def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]: + def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: """ Extracts the responses from the Amazon Bedrock streaming response. :param stream: The streaming response from the Amazon Bedrock request. - :param stream_handler: The handler for the streaming response. + :param streaming_callback: The handler for the streaming response. :returns: A list of string responses. """ - tokens: List[str] = [] + streaming_chunks: List[StreamingChunk] = [] for event in stream: chunk = event.get("chunk") if chunk: decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(decoded_chunk) - tokens.append(stream_handler(token, event_data=decoded_chunk)) - responses = ["".join(tokens).lstrip()] + streaming_chunk: StreamingChunk = self._build_streaming_chunk(decoded_chunk) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) + + responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()] return responses def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: @@ -84,12 +86,12 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ @abstractmethod - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ @@ -98,6 +100,10 @@ class AnthropicClaudeAdapter(BedrockModelAdapter): Adapter for the Anthropic Claude models. """ + def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None: + self.use_messages_api = model_kwargs.get("use_messages_api", True) + super().__init__(model_kwargs, max_length) + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: """ Prepares the body for the Claude model @@ -108,16 +114,30 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - `prompt`: The prompt to be sent to the model. - specified inference parameters. 
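The streaming rework above swaps the old `TokenStreamingHandler` for a plain callable that receives Haystack `StreamingChunk` objects. A minimal sketch of such a callback, using only the `content` and `meta` fields exercised in this patch (the function name and the printing behaviour are illustrative assumptions):

```python
from haystack.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    # chunk.content carries the newly streamed text; chunk.meta holds the decoded Bedrock event
    print(chunk.content, end="", flush=True)
```

Any callable with this signature can be passed as `streaming_callback` to `get_stream_responses`.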
""" - default_params = { - "max_tokens_to_sample": self.max_length, - "stop_sequences": ["\n\nHuman:"], - "temperature": None, - "top_p": None, - "top_k": None, - } - params = self._get_params(inference_kwargs, default_params) - - body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} + if self.use_messages_api: + default_params: Dict[str, Any] = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": self.max_length, + "system": None, + "stop_sequences": None, + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"messages": [{"role": "user", "content": prompt}], **params} + else: + default_params = { + "max_tokens_to_sample": self.max_length, + "stop_sequences": ["\n\nHuman:"], + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} return body def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: @@ -127,16 +147,22 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L :param response_body: The response body from the Amazon Bedrock request. :returns: A list of string responses. """ + if self.use_messages_api: + return [content["text"] for content in response_body["content"]] + return [response_body["completion"]] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("completion", "") + if self.use_messages_api: + return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) + + return StreamingChunk(content=chunk.get("completion", ""), meta=chunk) class MistralAdapter(BedrockModelAdapter): @@ -175,17 +201,18 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ return [output.get("text", "") for output in response_body.get("outputs", [])] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ + content = "" chunk_list = chunk.get("outputs", []) if chunk_list: - return chunk_list[0].get("text", "") - return "" + content = chunk_list[0].get("text", "") + return StreamingChunk(content=content, meta=chunk) class CohereCommandAdapter(BedrockModelAdapter): @@ -230,14 +257,74 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [generation["text"] for generation in response_body["generations"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: + """ + Extracts the content and meta from a streaming chunk. + + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. 
+ """ + return StreamingChunk(content=chunk.get("text", ""), meta=chunk) + + +class CohereCommandRAdapter(BedrockModelAdapter): + """ + Adapter for the Cohere Command R models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]: + """ + Prepares the body for the Command model + + :param prompt: The prompt to be sent to the model. + :param inference_kwargs: Additional keyword arguments passed to the handler. + :returns: A dictionary with the following keys: + - `prompt`: The prompt to be sent to the model. + - specified inference parameters. + """ + default_params = { + "chat_history": None, + "documents": None, + "search_query_only": None, + "preamble": None, + "max_tokens": self.max_length, + "temperature": None, + "p": None, + "k": None, + "prompt_truncation": None, + "frequency_penalty": None, + "presence_penalty": None, + "seed": None, + "return_prompt": None, + "tools": None, + "tool_results": None, + "stop_sequences": None, + "raw_prompting": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"message": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + """ + Extracts the responses from the Cohere Command model response. + + :param response_body: The response body from the Amazon Bedrock request. + :returns: A list of string responses. + """ + responses = [response_body["text"]] + return responses + + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("text", "") + token: str = chunk.get("text", "") + return StreamingChunk(content=token, meta=chunk) class AI21LabsJurassic2Adapter(BedrockModelAdapter): @@ -273,7 +360,7 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [completion["data"]["text"] for completion in response_body["completions"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: msg = "Streaming is not supported for AI21 Jurassic 2 models." raise NotImplementedError(msg) @@ -314,17 +401,17 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L responses = [result["outputText"] for result in response_body["results"]] return responses - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("outputText", "") + return StreamingChunk(content=chunk.get("outputText", ""), meta=chunk) -class MetaLlama2ChatAdapter(BedrockModelAdapter): +class MetaLlamaAdapter(BedrockModelAdapter): """ Adapter for Meta's Llama2 models. 
""" @@ -358,11 +445,11 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> L """ return [response_body["generation"]] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: A string token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ - return chunk.get("generation", "") + return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index f0a2ea368..1f0430810 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -1,7 +1,8 @@ import json import logging +import os from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List +from typing import Any, Callable, ClassVar, Dict, List, Optional from botocore.eventstream import EventStream from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -20,11 +21,12 @@ class BedrockModelChatAdapter(ABC): focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs. """ - def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: """ - Initializes the chat adapter with the generation kwargs. + Initializes the chat adapter with the truncate parameter and generation kwargs. 
""" self.generation_kwargs = generation_kwargs + self.truncate = truncate @abstractmethod def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -47,19 +49,18 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: return self._extract_messages_from_response(response_body) def get_stream_responses( - self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None] + self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: - tokens: List[str] = [] + streaming_chunks: List[StreamingChunk] = [] last_decoded_chunk: Dict[str, Any] = {} for event in stream: chunk = event.get("chunk") if chunk: last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(last_decoded_chunk) - stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only - stream_handler(stream_chunk) # callback the stream handler with StreamingChunk - tokens.append(token) - responses = ["".join(tokens).lstrip()] + streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) + streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk + streaming_chunks.append(streaming_chunk) + responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod @@ -141,12 +142,12 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List """ @abstractmethod - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ @@ -166,13 +167,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): "system", ] - def __init__(self, generation_kwargs: Dict[str, Any]): + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): """ Initializes the Anthropic Claude chat adapter. + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. :param generation_kwargs: The generation kwargs. """ - super().__init__(generation_kwargs) + super().__init__(truncate, generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed @@ -216,7 +218,7 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: Prepares the chat messages for the Anthropic Claude request. :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string. + :returns: The prepared chat messages as a dictionary. 
""" body: Dict[str, Any] = {} system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None @@ -225,6 +227,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: ] if system: body["system"] = system + # Ensure token limit for each message in the body + if self.truncate: + for message in body["messages"]: + for content in message["content"]: + content["text"] = self._ensure_token_limit(content["text"]) return body def check_prompt(self, prompt: str) -> Dict[str, Any]: @@ -251,16 +258,16 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) return messages - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return chunk.get("delta", {}).get("text", "") - return "" + return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: """ @@ -316,13 +323,13 @@ class MistralChatAdapter(BedrockModelChatAdapter): "top_p", ] - def __init__(self, generation_kwargs: Dict[str, Any]): + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): """ Initializes the Mistral chat adapter. - + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. :param generation_kwargs: The generation kwargs. """ - super().__init__(generation_kwargs) + super().__init__(truncate, generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed @@ -332,11 +339,23 @@ def __init__(self, generation_kwargs: Dict[str, Any]): # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer # a) we should get good estimates for the prompt length # b) we can use apply_chat_template with the template above to delineate ChatMessages - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. + tokenizer: PreTrainedTokenizer + if os.environ.get("HF_TOKEN"): + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + else: + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") + logger.warning( + "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " + "estimating the prompt length because no HF_TOKEN was found. Using " + "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " + "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." 
+ ) + self.prompt_handler = DefaultPromptHandler( tokenizer=tokenizer, model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_gen_len") or 512, + max_length=self.generation_kwargs.get("max_tokens") or 512, ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -372,7 +391,9 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template ) - return self._ensure_token_limit(prepared_prompt) + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]: """ @@ -412,17 +433,17 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) return messages - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. """ response_chunk = chunk.get("outputs", []) if response_chunk: - return response_chunk[0].get("text", "") - return "" + return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) class MetaLlama2ChatAdapter(BedrockModelChatAdapter): @@ -458,12 +479,13 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): "{% endfor %}" ) - def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: """ Initializes the Meta Llama 2 chat adapter. + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. :param generation_kwargs: The generation kwargs. """ - super().__init__(generation_kwargs) + super().__init__(truncate, generation_kwargs) # We pop the model_max_length as it is not sent to the model # but used to truncate the prompt if needed # Llama 2 has context window size of 4096 tokens @@ -507,7 +529,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( conversation=messages, tokenize=False, chat_template=self.chat_template ) - return self._ensure_token_limit(prepared_prompt) + + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt def check_prompt(self, prompt: str) -> Dict[str, Any]: """ @@ -530,11 +555,11 @@ def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List metadata = {k: v for (k, v) in response_body.items() if k != message_tag} return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: """ - Extracts the token from a streaming chunk. + Extracts the content and meta from a streaming chunk. - :param chunk: The streaming chunk. - :returns: The extracted token. + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. 
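Condensing the tokenizer selection above into a standalone sketch: the gated Mistral tokenizer is loaded only when a Hugging Face token is available; otherwise the non-gated Llama 2 chat tokenizer is used purely for prompt-length estimation.

```python
import os

from transformers import AutoTokenizer

if os.environ.get("HF_TOKEN"):
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
else:
    # non-gated fallback, good enough for estimating prompt length
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
```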
""" - return chunk.get("generation", "") + return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 960968755..5fa9e0b8a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -23,11 +23,13 @@ @component class AmazonBedrockChatGenerator: """ - `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs. + Completes chats using LLMs hosted on Amazon Bedrock. - For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the + For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the 'anthropic.claude-3-sonnet-20240229-v1:0' model name. + ### Usage example + ```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack.dataclasses import ChatMessage @@ -43,8 +45,16 @@ class AmazonBedrockChatGenerator: ``` - If you prefer non-streaming mode, simply remove the `streaming_callback` parameter, capture the return value of the - component's run method and the `AmazonBedrockChatGenerator` will return the response in a non-streaming mode. + AmazonBedrockChatGenerator uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM. + For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation] + (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html). + + If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded + automatically from the environment or the AWS configuration file. + If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name` as environment variables or pass them as + [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set + supports Amazon Bedrock. """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { @@ -66,6 +76,7 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + truncate: Optional[bool] = True, ): """ Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the @@ -77,23 +88,28 @@ def __init__( constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`, and `aws_region_name`. - :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to - be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). + :param model: The model to use for text generation. The model must be available in Amazon Bedrock and must + be specified in the format outlined in the [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). :param aws_access_key_id: AWS access key ID. 
:param aws_secret_access_key: AWS secret access key. :param aws_session_token: AWS session token. - :param aws_region_name: AWS region name. + :param aws_region_name: AWS region name. Make sure the region you set supports Amazon Bedrock. :param aws_profile_name: AWS profile name. - :param generation_kwargs: Additional generation keyword arguments passed to the model. The defined keyword - parameters are specific to a specific model and can be found in the model's documentation. For example, the - Anthropic Claude generation parameters can be found [here](https://docs.anthropic.com/claude/reference/complete_post). - :param stop_words: A list of stop words that stop model generation when encountered. They can be provided via - this parameter or via models generation_kwargs under a model's specific key for stop words. For example, the - Anthropic Claude stop words are provided via the `stop_sequences` key. - :param streaming_callback: A callback function that is called when a new chunk is received from the stream. - By default, the model is not set up for streaming. To enable streaming simply set this parameter to a callback - function that will handle the streaming chunks. The callback function will receive a StreamingChunk object and - switch the streaming mode on. + :param generation_kwargs: Keyword arguments sent to the model. These + parameters are specific to a model. You can find them in the model's documentation. + For example, you can find the + Anthropic Claude generation parameters in [Anthropic documentation](https://docs.anthropic.com/claude/reference/complete_post). + :param stop_words: A list of stop words that stop the model from generating more text + when encountered. You can provide them using + this parameter or using the model's `generation_kwargs` under a model's specific key for stop words. + For example, you can provide + stop words for Anthropic Claude in the `stop_sequences` key. + :param streaming_callback: A callback function called when a new token is received from the stream. + By default, the model is not set up for streaming. To enable streaming, set this parameter to a callback + function that handles the streaming chunks. The callback function receives a + [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object. Passing a + callback switches the streaming mode on. + :param truncate: Whether to truncate the prompt messages if they exceed the model's maximum token limit. """ if not model: msg = "'model' cannot be None or empty string" raise ValueError(msg) @@ -104,13 +120,14 @@ def __init__( self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.truncate = truncate # get the model adapter for the given model model_adapter_cls = self.get_model_adapter(model=model) if not model_adapter_cls: msg = f"AmazonBedrockGenerator doesn't support the model {model}."
raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(generation_kwargs or {}) + self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {}) # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -135,18 +152,29 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: self.stop_words = stop_words or [] self.streaming_callback = streaming_callback - def invoke(self, *args, **kwargs): + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ - Invokes the Amazon Bedrock LLM with the given parameters. The parameters are passed to the Amazon Bedrock - client. + Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM. - :param args: The positional arguments passed to the generator. - :param kwargs: The keyword arguments passed to the generator. - :returns: List of `ChatMessage` generated by LLM. + :param messages: The messages to generate a response to. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Additional generation keyword arguments passed to the model. + :returns: A dictionary with the following keys: + - `replies`: The generated list of `ChatMessage` objects. """ + generation_kwargs = generation_kwargs or {} + generation_kwargs = generation_kwargs.copy() + + streaming_callback = streaming_callback or self.streaming_callback + generation_kwargs["stream"] = streaming_callback is not None - kwargs = kwargs.copy() - messages: List[ChatMessage] = kwargs.pop("messages", []) # check if the prompt is a list of ChatMessage objects if not ( isinstance(messages, list) @@ -156,39 +184,35 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) + body = self.model_adapter.prepare_body( + messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} + ) try: - if self.streaming_callback: + if streaming_callback: response = self.client.invoke_model_with_response_stream( body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_stream = response["body"] - responses = self.model_adapter.get_stream_responses( - stream=response_stream, stream_handler=self.streaming_callback + replies = self.model_adapter.get_stream_responses( + stream=response_stream, streaming_callback=streaming_callback ) else: response = self.client.invoke_model( body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) response_body = json.loads(response.get("body").read().decode("utf-8")) - responses = self.model_adapter.get_responses(response_body=response_body) + replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception - return responses - - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. + # rename the meta key to be inline with OpenAI meta output keys + for response in replies: + if response.meta is not None and "usage" in response.meta: + response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") + response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - :param messages: The messages to generate a response to. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated List of `ChatMessage` objects. 
- """ - return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))} + return {"replies": replies} @classmethod def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: @@ -222,6 +246,7 @@ def to_dict(self) -> Dict[str, Any]: stop_words=self.stop_words, generation_kwargs=self.model_adapter.generation_kwargs, streaming_callback=callback_name, + truncate=self.truncate, ) @classmethod diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 81a02b749..6ef0a4765 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,11 +1,12 @@ import json import logging import re -from typing import Any, ClassVar, Dict, List, Optional, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Type from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack.dataclasses import StreamingChunk +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, @@ -19,13 +20,12 @@ AnthropicClaudeAdapter, BedrockModelAdapter, CohereCommandAdapter, - MetaLlama2ChatAdapter, + CohereCommandRAdapter, + MetaLlamaAdapter, MistralAdapter, ) from .handlers import ( DefaultPromptHandler, - DefaultTokenStreamingHandler, - TokenStreamingHandler, ) logger = logging.getLogger(__name__) @@ -34,13 +34,14 @@ @component class AmazonBedrockGenerator: """ - `AmazonBedrockGenerator` enables text generation via Amazon Bedrock hosted LLMs. + Generates text using models hosted on Amazon Bedrock. - For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockGenerator` with the - 'anthropic.claude-v2' model name. Provide AWS credentials either via local AWS profile or directly via + For example, to use the Anthropic Claude model, pass 'anthropic.claude-v2' in the `model` parameter. + Provide AWS credentials either through the local AWS profile or directly through `aws_access_key_id`, `aws_secret_access_key`, `aws_session_token`, and `aws_region_name` parameters. - Usage example: + ### Usage example + ```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator @@ -51,14 +52,25 @@ class AmazonBedrockGenerator: print(generator.run("Who is the best American actor?")) ``` + + AmazonBedrockGenerator uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM. + For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation] + (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html). + If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded + automatically from the environment or the AWS configuration file. 
+ If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`, + `aws_session_token`, and `aws_region_name` as environment variables or pass them as + [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set + supports Amazon Bedrock. """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { r"amazon.titan-text.*": AmazonTitanAdapter, r"ai21.j2.*": AI21LabsJurassic2Adapter, - r"cohere.command.*": CohereCommandAdapter, + r"cohere.command-[^r].*": CohereCommandAdapter, + r"cohere.command-r.*": CohereCommandRAdapter, r"anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, + r"meta.llama.*": MetaLlamaAdapter, r"mistral.*": MistralAdapter, } @@ -73,6 +85,8 @@ def __init__( aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 max_length: Optional[int] = 100, + truncate: Optional[bool] = True, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, **kwargs, ): """ @@ -82,10 +96,14 @@ def __init__( :param aws_access_key_id: The AWS access key ID. :param aws_secret_access_key: The AWS secret access key. :param aws_session_token: The AWS session token. - :param aws_region_name: The AWS region name. + :param aws_region_name: The AWS region name. Make sure the region you set supports Amazon Bedrock. :param aws_profile_name: The AWS profile name. :param max_length: The maximum length of the generated text. + :param truncate: Whether to truncate the prompt if it exceeds the model's maximum token limit. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. :param kwargs: Additional keyword arguments to be passed to the model. + These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is not supported.
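For orientation only (this example is not part of the patch): a minimal sketch of how the new `truncate` and `streaming_callback` parameters of `AmazonBedrockGenerator` would be used, assuming AWS credentials are already configured in the environment and using Haystack's `print_streaming_chunk` helper as the callback.

```python
# Not part of the patch: a minimal usage sketch of the new parameters,
# assuming AWS credentials are already available in the environment.
from haystack.components.generators.utils import print_streaming_chunk
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

generator = AmazonBedrockGenerator(
    model="anthropic.claude-v2",
    max_length=100,
    truncate=False,                            # new: keep the prompt as-is, even if it exceeds the context window
    streaming_callback=print_streaming_chunk,  # new: called once per StreamingChunk as tokens arrive
)

# run() replaces the old invoke() entry point and returns a dictionary.
result = generator.run("Which service hosts this model?")
print(result["replies"])
```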
@@ -95,11 +113,14 @@ def __init__( raise ValueError(msg) self.model = model self.max_length = max_length + self.truncate = truncate self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.streaming_callback = streaming_callback + self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -127,6 +148,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: # Truncate prompt if prompt tokens > model_max_length-max_length # (max_length is the length of the generated text) # we use GPT2 tokenizer which will likely provide good token count approximation + self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, @@ -139,7 +161,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: raise AmazonBedrockConfigurationError(msg) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) - def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]: + def _ensure_token_limit(self, prompt: str) -> str: """ Ensures that the prompt and answer token lengths together are within the model_max_length specified during the initialization of the component. @@ -147,14 +169,6 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union :param prompt: The prompt to be sent to the model. :returns: The resized prompt. """ - # the prompt for this model will be of the type str - if isinstance(prompt, List): - msg = ( - "AmazonBedrockGenerator only supports a string as a prompt, " - "while currently, the prompt is of type List." - ) - raise ValueError(msg) - resize_info = self.prompt_handler(prompt) if resize_info["prompt_length"] != resize_info["new_prompt_length"]: logger.warning( @@ -168,28 +182,36 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union ) return str(resize_info["resized_prompt"]) - def invoke(self, *args, **kwargs): + @component.output_types(replies=List[str]) + def run( + self, + prompt: str, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): """ - Invokes the model with the given prompt. + Generates a list of string responses to the given prompt. - :param args: Additional positional arguments passed to the generator. - :param kwargs: Additional keyword arguments passed to the generator. - :returns: A list of generated responses (strings). + :param prompt: The prompt to generate a response for. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Additional keyword arguments passed to the generator. + :returns: A dictionary with the following keys: + - `replies`: A list of generated responses. + :raises ValueError: If the prompt is empty or None. + :raises AmazonBedrockInferenceError: If the model cannot be invoked.
""" - kwargs = kwargs.copy() - prompt: str = kwargs.pop("prompt", None) - stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False)) + generation_kwargs = generation_kwargs or {} + generation_kwargs = generation_kwargs.copy() + streaming_callback = streaming_callback or self.streaming_callback + generation_kwargs["stream"] = streaming_callback is not None - if not prompt or not isinstance(prompt, (str, list)): - msg = ( - f"The model {self.model} requires a valid prompt, but currently, it has no prompt. " - f"Make sure to provide a prompt in the format that the model expects." - ) - raise ValueError(msg) + if self.truncate: + prompt = self._ensure_token_limit(prompt) - body = self.model_adapter.prepare_body(prompt=prompt, **kwargs) + body = self.model_adapter.prepare_body(prompt=prompt, **generation_kwargs) try: - if stream: + if streaming_callback: response = self.client.invoke_model_with_response_stream( body=json.dumps(body), modelId=self.model, @@ -197,11 +219,9 @@ def invoke(self, *args, **kwargs): contentType="application/json", ) response_stream = response["body"] - handler: TokenStreamingHandler = kwargs.get( - "stream_handler", - self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()), + replies = self.model_adapter.get_stream_responses( + stream=response_stream, streaming_callback=streaming_callback ) - responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( body=json.dumps(body), @@ -210,7 +230,7 @@ def invoke(self, *args, **kwargs): contentType="application/json", ) response_body = json.loads(response.get("body").read().decode("utf-8")) - responses = self.model_adapter.get_responses(response_body=response_body) + replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: msg = ( f"Could not connect to Amazon Bedrock model {self.model}. " @@ -219,21 +239,7 @@ def invoke(self, *args, **kwargs): ) raise AmazonBedrockInferenceError(msg) from exception - return responses - - @component.output_types(replies=List[str]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Generates a list of string response to the given prompt. - - :param prompt: The prompt to generate a response for. - :param generation_kwargs: Additional keyword arguments passed to the generator. - :returns: A dictionary with the following keys: - - `replies`: A list of generated responses. - :raises ValueError: If the prompt is empty or None. - :raises AmazonBedrockInferenceError: If the model cannot be invoked. - """ - return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))} + return {"replies": replies} @classmethod def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: @@ -255,6 +261,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( self, aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, @@ -264,6 +271,9 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, max_length=self.max_length, + truncate=self.truncate, + streaming_callback=callback_name, + **self.kwargs, ) @classmethod @@ -280,4 +290,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": data["init_parameters"], ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py index f4dc1aa4f..07db2742f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from typing import Dict, Union from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast @@ -61,35 +60,3 @@ def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]: "model_max_length": self.model_max_length, "max_length": self.max_length, } - - -class TokenStreamingHandler(ABC): - """ - TokenStreamingHandler implementations handle the streaming of tokens from the stream. - """ - - DONE_MARKER = "[DONE]" - - @abstractmethod - def __call__(self, token_received: str, **kwargs) -> str: - """ - This callback method is called when a new token is received from the stream. - - :param token_received: The token received from the stream. - :param kwargs: Additional keyword arguments passed to the handler. - :returns: The token to be sent to the stream. - """ - pass - - -class DefaultTokenStreamingHandler(TokenStreamingHandler): - def __call__(self, token_received, **kwargs) -> str: - """ - This callback method is called when a new token is received from the stream. - - :param token_received: The token received from the stream. - :param kwargs: Additional keyword arguments passed to the handler. - :returns: The token to be sent to the stream. 
- """ - print(token_received, flush=True, end="") # noqa: T201 - return token_received diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index b44f0605e..64e9ce2ef 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,4 +1,7 @@ +import logging +import os from typing import Optional, Type +from unittest.mock import MagicMock, patch import pytest from haystack.components.generators.utils import print_streaming_chunk @@ -42,6 +45,7 @@ def test_to_dict(mock_boto3_session): "generation_kwargs": {"temperature": 0.7}, "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "truncate": True, }, } @@ -64,6 +68,7 @@ def test_from_dict(mock_boto3_session): "model": "anthropic.claude-v2", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "truncate": True, }, } ) @@ -82,7 +87,7 @@ def test_default_constructor(mock_boto3_session, set_env_variables): ) assert layer.model == "anthropic.claude-v2" - + assert layer.truncate is True assert layer.model_adapter.prompt_handler is not None assert layer.model_adapter.prompt_handler.model_max_length == 100000 @@ -108,6 +113,15 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) assert "temperature" in layer.model_adapter.generation_kwargs assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 + assert layer.model_adapter.truncate is True + + +def test_constructor_with_truncate(mock_boto3_session): + """ + Test that truncate param is correctly set in the model constructor + """ + layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False) + assert layer.model_adapter.truncate is False def test_constructor_with_empty_model(): @@ -118,13 +132,121 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -def test_invoke_with_no_kwargs(mock_boto3_session): +def test_short_prompt_is_not_truncated(mock_boto3_session): + """ + Test that a short prompt is not truncated + """ + # Define a short mock prompt and its tokenized version + mock_prompt_text = "I am a tokenized prompt" + mock_prompt_tokens = mock_prompt_text.split() + + # Mock the tokenizer so it returns our predefined tokens + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = mock_prompt_tokens + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Since our mock prompt is 5 tokens long, it doesn't exceed the + # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockChatGenerator( + "anthropic.claude-v2", + generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, + ) + prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text) + + # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it + assert prompt_after_resize == mock_prompt_text + + +def test_long_prompt_is_truncated(mock_boto3_session): """ - Test invoke raises an error if no messages are provided + Test that a long prompt is truncated """ - layer = 
AmazonBedrockChatGenerator(model="anthropic.claude-v2") - with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires"): - layer.invoke() + # Define a long mock prompt and its tokenized version + long_prompt_text = "I am a tokenized prompt of length eight" + long_prompt_tokens = long_prompt_text.split() + + # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit + truncated_prompt_text = "I am a tokenized prompt of length" + + # Mock the tokenizer to return our predefined tokens + # convert tokens to our predefined truncated text + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = long_prompt_tokens + mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockChatGenerator( + "anthropic.claude-v2", + generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, + ) + prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text) + + # The prompt exceeds the limit, _ensure_token_limit truncates it + assert prompt_after_resize == truncated_prompt_text + + +def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): + """ + Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False + """ + messages = [ChatMessage.from_system("What is the biggest city in United States?")] + + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): + generator = AmazonBedrockChatGenerator( + model="anthropic.claude-v2", + truncate=False, + generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, + ) + + # Mock the _ensure_token_limit method to track if it is called + with patch.object( + generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit + ) as mock_ensure_token_limit: + # Mock the model adapter to avoid actual invocation + generator.model_adapter.prepare_body = MagicMock(return_value={}) + generator.client = MagicMock() + generator.client.invoke_model = MagicMock( + return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} + ) + + generator.model_adapter.get_responses = MagicMock( + return_value=[ + ChatMessage( + content="Some text", + role=ChatRole.ASSISTANT, + name=None, + meta=[ + { + "model": "claude-3-sonnet-20240229", + "index": 0, + "finish_reason": "end_turn", + "usage": {"prompt_tokens": 16, "completion_tokens": 55}, + } + ], + ) + ] + ) + # Invoke the generator + generator.run(messages=messages) + + # Ensure _ensure_token_limit was not called + mock_ensure_token_limit.assert_not_called(), + + # Check the prompt passed to prepare_body + generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) @pytest.mark.parametrize( @@ -150,7 +272,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: 
Optional[Type[Bed class TestAnthropicClaudeAdapter: def test_prepare_body_with_default_params(self) -> None: - layer = AnthropicClaudeChatAdapter(generation_kwargs={}) + layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={}) prompt = "Hello, how are you?" expected_body = { "anthropic_version": "bedrock-2023-05-31", @@ -163,7 +285,9 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) + layer = AnthropicClaudeChatAdapter( + truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4} + ) prompt = "Hello, how are you?" expected_body = { "anthropic_version": "bedrock-2023-05-31", @@ -184,7 +308,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: class TestMistralAdapter: def test_prepare_body_with_default_params(self) -> None: - layer = MistralChatAdapter(generation_kwargs={}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={}) prompt = "Hello, how are you?" expected_body = { "max_tokens": 512, @@ -196,7 +320,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MistralChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" expected_body = { "prompt": "[INST] Hello, how are you? [/INST]", @@ -210,12 +334,12 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body def test_mistral_chat_template_correct_order(self): - layer = MistralChatAdapter(generation_kwargs={}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={}) layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")]) layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")]) def test_mistral_chat_template_incorrect_order(self): - layer = MistralChatAdapter(generation_kwargs={}) + layer = MistralChatAdapter(truncate=True, generation_kwargs={}) try: layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]) msg = "Expected TemplateError" @@ -237,6 +361,33 @@ def test_mistral_chat_template_incorrect_order(self): except Exception as e: assert "Conversation roles must alternate user/assistant/" in str(e) + def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None: + monkeypatch.delenv("HF_TOKEN", raising=False) + with ( + patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, + patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), + caplog.at_level(logging.WARNING), + ): + MistralChatAdapter(truncate=True, generation_kwargs={}) + mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf") + assert "no HF_TOKEN was found" in caplog.text + + def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None: + monkeypatch.setenv("HF_TOKEN", "test") + with ( + patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, + patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), + ): + MistralChatAdapter(truncate=True, generation_kwargs={}) + 
mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1") + + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason=( + "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also " + "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`" + ), + ) @pytest.mark.parametrize("model_name", MISTRAL_MODELS) @pytest.mark.integration def test_default_inference_params(self, model_name, chat_messages): @@ -270,7 +421,7 @@ class TestMetaLlama2ChatAdapter: def test_prepare_body_with_default_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter(generation_kwargs={}) + layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) prompt = "Hello, how are you?" expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} @@ -283,7 +434,8 @@ def test_prepare_body_with_custom_inference_params(self) -> None: # leave this test as integration because we really need only tokenizer from HF # that way we can ensure prompt chat message formatting layer = MetaLlama2ChatAdapter( - generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} + truncate=True, + generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}, ) prompt = "Hello, how are you?" @@ -308,7 +460,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: @pytest.mark.integration def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(generation_kwargs={}) + adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) response_body = {"generation": "This is a single response."} expected_response = "This is a single response." 
response_message = adapter.get_responses(response_body) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 10fc1eca8..f0233888c 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, call, patch import pytest +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( @@ -10,7 +11,8 @@ AnthropicClaudeAdapter, BedrockModelAdapter, CohereCommandAdapter, - MetaLlama2ChatAdapter, + CohereCommandRAdapter, + MetaLlamaAdapter, MistralAdapter, ) @@ -19,10 +21,7 @@ def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ - generator = AmazonBedrockGenerator( - model="anthropic.claude-v2", - max_length=99, - ) + generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", @@ -34,6 +33,9 @@ def test_to_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, + "truncate": False, + "temperature": 10, + "streaming_callback": None, }, } @@ -120,15 +122,6 @@ def test_constructor_with_empty_model(): AmazonBedrockGenerator(model="") -def test_invoke_with_no_kwargs(mock_boto3_session): - """ - Test invoke raises an error if no prompt is provided - """ - layer = AmazonBedrockGenerator(model="anthropic.claude-v2") - with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."): - layer.invoke() - - def test_short_prompt_is_not_truncated(mock_boto3_session): """ Test that a short prompt is not truncated @@ -193,6 +186,46 @@ def test_long_prompt_is_truncated(mock_boto3_session): assert prompt_after_resize == truncated_prompt_text +def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): + """ + Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False + """ + long_prompt_text = "I am a tokenized prompt of length eight" + + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): + generator = AmazonBedrockGenerator( + model="anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, + truncate=False, + ) + + # Mock the _ensure_token_limit method to track if it is called + with patch.object( + generator, "_ensure_token_limit", wraps=generator._ensure_token_limit + ) as mock_ensure_token_limit: + # Mock the model adapter to avoid actual invocation + generator.model_adapter.prepare_body = MagicMock(return_value={}) + generator.client = MagicMock() + generator.client.invoke_model = MagicMock( + return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} + ) + generator.model_adapter.get_responses = MagicMock(return_value=["response"]) + + # Invoke the generator + generator.run(prompt=long_prompt_text) + + # Ensure 
_ensure_token_limit was not called + mock_ensure_token_limit.assert_not_called(), + + # Check the prompt passed to prepare_body + generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) + + @pytest.mark.parametrize( "model, expected_model_adapter", [ @@ -203,6 +236,9 @@ def test_long_prompt_is_truncated(mock_boto3_session): ("cohere.command-text-v14", CohereCommandAdapter), ("cohere.command-light-text-v14", CohereCommandAdapter), ("cohere.command-text-v21", CohereCommandAdapter), # artificial + ("cohere.command-r-v1:0", CohereCommandRAdapter), + ("cohere.command-r-plus-v1:0", CohereCommandRAdapter), + ("cohere.command-r-v8:9", CohereCommandRAdapter), # artificial ("ai21.j2-mid-v1", AI21LabsJurassic2Adapter), ("ai21.j2-ultra-v1", AI21LabsJurassic2Adapter), ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial @@ -210,9 +246,16 @@ def test_long_prompt_is_truncated(mock_boto3_session): ("amazon.titan-text-express-v1", AmazonTitanAdapter), ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial - ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), + ("meta.llama2-70b-chat-v1", MetaLlamaAdapter), + ("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial + ("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter), + ("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter), + ("meta.llama3-130b-instruct-v5:9", MetaLlamaAdapter), # artificial + ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), + ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), + ("mistral.mistral-large-2402-v1:0", MistralAdapter), + ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ("unknown_model", None), ], ) @@ -225,9 +268,179 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed class TestAnthropicClaudeAdapter: + def test_default_init(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100) + assert adapter.use_messages_api is True + + def test_use_messages_api_false(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=100) + assert adapter.use_messages_api is False + + +class TestAnthropicClaudeAdapterMessagesAPI: def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 99, + "anthropic_version": "bedrock-2023-05-31", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" 
+ expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + stop_sequences=["CUSTOM_STOP"], + system="system prompt", + anthropic_version="custom_version", + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "system": "system prompt", + "anthropic_version": "custom_version", + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_tokens": 49, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "system": "system prompt", + "anthropic_version": "custom_version", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "new system prompt", + "anthropic_version": "new_custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + system="new system prompt", + anthropic_version="new_custom_version", + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + streaming_callback_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"delta": {"text": " This"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " is"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " a"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " single"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " response."}}'}}, + ] + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses + + streaming_callback_mock.assert_has_calls( + [ + call(StreamingChunk(content=" This", 
meta={"delta": {"text": " This"}})), + call(StreamingChunk(content=" is", meta={"delta": {"text": " is"}})), + call(StreamingChunk(content=" a", meta={"delta": {"text": " a"}})), + call(StreamingChunk(content=" single", meta={"delta": {"text": " single"}})), + call(StreamingChunk(content=" response.", meta={"delta": {"text": " response."}})), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + streaming_callback_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses + + streaming_callback_mock.assert_not_called() + + +class TestAnthropicClaudeAdapterNoMessagesAPI: + def test_prepare_body_with_default_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) + prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", "max_tokens_to_sample": 99, @@ -239,7 +452,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", @@ -265,6 +478,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: def test_prepare_body_with_model_kwargs(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.7, "top_p": 0.8, "top_k": 5, @@ -291,6 +505,7 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.6, "top_p": 0.7, "top_k": 4, @@ -314,20 +529,20 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non assert body == expected_body def test_get_responses(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"completion": " This"}'}}, @@ -337,35 +552,31 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"completion": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = 
AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"completion": " This"}), - call(" is", event_data={"completion": " is"}), - call(" a", event_data={"completion": " a"}), - call(" single", event_data={"completion": " single"}), - call(" response.", event_data={"completion": " response."}), + call(StreamingChunk(content=" This", meta={"completion": " This"})), + call(StreamingChunk(content=" is", meta={"completion": " is"})), + call(StreamingChunk(content=" a", meta={"completion": " a"})), + call(StreamingChunk(content=" single", meta={"completion": " single"})), + call(StreamingChunk(content=" response.", meta={"completion": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestMistralAdapter: @@ -460,7 +671,7 @@ def test_get_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"outputs": [{"text": " This"}]}'}}, @@ -470,35 +681,33 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputs": [{"text": " response."}]}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"outputs": [{"text": " This"}]}), - call(" is", event_data={"outputs": [{"text": " is"}]}), - call(" a", event_data={"outputs": [{"text": " a"}]}), - call(" single", event_data={"outputs": [{"text": " single"}]}), - call(" response.", event_data={"outputs": [{"text": " response."}]}), + call(StreamingChunk(content=" This", meta={"outputs": [{"text": " This"}]})), + call(StreamingChunk(content=" is", meta={"outputs": [{"text": " is"}]})), + call(StreamingChunk(content=" a", meta={"outputs": [{"text": " a"}]})), + call(StreamingChunk(content=" single", meta={"outputs": [{"text": " single"}]})), + call(StreamingChunk(content=" response.", meta={"outputs": [{"text": " response."}]})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() 
+ streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + streaming_callback_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MistralAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() class TestCohereCommandAdapter: @@ -655,7 +864,7 @@ def test_get_responses_multiple_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"text": " This"}'}}, @@ -666,36 +875,140 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"text": " This"}), - call(" is", event_data={"text": " is"}), - call(" a", event_data={"text": " a"}), - call(" single", event_data={"text": " single"}), - call(" response.", event_data={"text": " response."}), - call("", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True}), + call(StreamingChunk(content=" This", meta={"text": " This"})), + call(StreamingChunk(content=" is", meta={"text": " is"})), + call(StreamingChunk(content=" a", meta={"text": " a"})), + call(StreamingChunk(content=" single", meta={"text": " single"})), + call(StreamingChunk(content=" response.", meta={"text": " response."})), + call(StreamingChunk(content="", meta={"finish_reason": "MAX_TOKENS", "is_finished": True})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() + + +class TestCohereCommandRAdapter: + def test_prepare_body(self) -> None: + adapter = CohereCommandRAdapter( + model_kwargs={ + "chat_history": [ + {"role": "CHATBOT", "content": "How can I help you today?"}, + ], + "documents": [ + {"title": "France", "snippet": "Paris is the capital of France."}, + {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, + ], + "search_query_only": False, + "preamble": "preamble", + "temperature": 0, + "p": 0.9, + "k": 50, + "prompt_truncation": "AUTO_PRESERVE_ORDER", + "frequency_penalty": 0.3, + "presence_penalty": 0.4, + 
"seed": 42, + "return_prompt": True, + "tools": [ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales " + "information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True, + } + }, + } + ], + "tool_results": [ + { + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "outputs": [ + {"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"} + ], + } + ], + "stop_sequences": ["\n\n"], + "raw_prompting": True, + "stream": True, + "unknown_arg": "unknown_arg", + }, + max_length=100, + ) + body = adapter.prepare_body(prompt="test") + assert body == { + "message": "test", + "chat_history": [ + {"role": "CHATBOT", "content": "How can I help you today?"}, + ], + "documents": [ + {"title": "France", "snippet": "Paris is the capital of France."}, + {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, + ], + "search_query_only": False, + "preamble": "preamble", + "max_tokens": 100, + "temperature": 0, + "p": 0.9, + "k": 50, + "prompt_truncation": "AUTO_PRESERVE_ORDER", + "frequency_penalty": 0.3, + "presence_penalty": 0.4, + "seed": 42, + "return_prompt": True, + "tools": [ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales " + "information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True, + } + }, + } + ], + "tool_results": [ + { + "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "outputs": [{"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}], + } + ], + "stop_sequences": ["\n\n"], + "raw_prompting": True, + } + + def test_extract_completions_from_response(self) -> None: + adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) + response_body = {"text": "response"} + completions = adapter._extract_completions_from_response(response_body=response_body) + assert completions == ["response"] + + def test_build_chunk(self) -> None: + adapter = CohereCommandRAdapter(model_kwargs={}, max_length=100) + chunk = {"text": "response_token"} + streaming_chunk = adapter._build_streaming_chunk(chunk=chunk) + assert streaming_chunk == StreamingChunk(content="response_token", meta=chunk) class TestAI21LabsJurassic2Adapter: @@ -954,7 +1267,7 @@ def test_get_responses_multiple_responses(self) -> None: def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"outputText": " This"}'}}, @@ -964,40 +1277,36 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputText": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"outputText": " 
This"}), - call(" is", event_data={"outputText": " is"}), - call(" a", event_data={"outputText": " a"}), - call(" single", event_data={"outputText": " single"}), - call(" response.", event_data={"outputText": " response."}), + call(StreamingChunk(content=" This", meta={"outputText": " This"})), + call(StreamingChunk(content=" is", meta={"outputText": " is"})), + call(StreamingChunk(content=" a", meta={"outputText": " a"})), + call(StreamingChunk(content=" single", meta={"outputText": " single"})), + call(StreamingChunk(content=" response.", meta={"outputText": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() -class TestMetaLlama2ChatAdapter: +class TestMetaLlamaAdapter: def test_prepare_body_with_default_params(self) -> None: - layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + layer = MetaLlamaAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 99} @@ -1006,7 +1315,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + layer = MetaLlamaAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" 
expected_body = { "prompt": "Hello, how are you?", @@ -1026,7 +1335,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body def test_prepare_body_with_model_kwargs(self) -> None: - layer = MetaLlama2ChatAdapter( + layer = MetaLlamaAdapter( model_kwargs={ "temperature": 0.7, "top_p": 0.8, @@ -1048,7 +1357,7 @@ def test_prepare_body_with_model_kwargs(self) -> None: assert body == expected_body def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: - layer = MetaLlama2ChatAdapter( + layer = MetaLlamaAdapter( model_kwargs={ "temperature": 0.6, "top_p": 0.7, @@ -1070,20 +1379,20 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non assert body == expected_body def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) response_body = {"generation": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) response_body = {"generation": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [ {"chunk": {"bytes": b'{"generation": " This"}'}}, @@ -1093,32 +1402,28 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"generation": " response."}'}}, ] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_has_calls( + streaming_callback_mock.assert_has_calls( [ - call(" This", event_data={"generation": " This"}), - call(" is", event_data={"generation": " is"}), - call(" a", event_data={"generation": " a"}), - call(" single", event_data={"generation": " single"}), - call(" response.", event_data={"generation": " response."}), + call(StreamingChunk(content=" This", meta={"generation": " This"})), + call(StreamingChunk(content=" is", meta={"generation": " is"})), + call(StreamingChunk(content=" a", meta={"generation": " a"})), + call(StreamingChunk(content=" single", meta={"generation": " single"})), + call(StreamingChunk(content=" response.", meta={"generation": " response."})), ] ) def test_get_stream_responses_empty(self) -> None: stream_mock = MagicMock() - stream_handler_mock = MagicMock() + streaming_callback_mock = MagicMock() stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - - adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + assert 
adapter.get_stream_responses(stream_mock, streaming_callback_mock) == expected_responses - stream_handler_mock.assert_not_called() + streaming_callback_mock.assert_not_called() diff --git a/integrations/amazon_sagemaker/CHANGELOG.md b/integrations/amazon_sagemaker/CHANGELOG.md new file mode 100644 index 000000000..edd15fc82 --- /dev/null +++ b/integrations/amazon_sagemaker/CHANGELOG.md @@ -0,0 +1,27 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Sagemaker integration: `SagemakerGenerator` (#276) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Small consistency improvements (#536) +- Review integrations sagemaker (#544) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + + diff --git a/integrations/amazon_sagemaker/pydoc/config.yml b/integrations/amazon_sagemaker/pydoc/config.yml index 20d51b25e..950e949f7 100644 --- a/integrations/amazon_sagemaker/pydoc/config.yml +++ b/integrations/amazon_sagemaker/pydoc/config.yml @@ -14,7 +14,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Amazon Sagemaker integration for Haystack category_slug: integrations-api title: Amazon Sagemaker diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index a1f6ff239..f8050bb48 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -13,9 +13,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -28,10 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "boto3>=1.28.57", -] +dependencies = ["haystack-ai", "boto3>=1.28.57"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker_haystack#readme" @@ -53,49 +48,30 @@ git_describe_command = 'git describe --tags --match="integrations/amazon_sagemak dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" 
-style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -140,9 +116,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -166,12 +148,8 @@ parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -185,7 +163,6 @@ module = [ ignore_missing_imports = true - [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py index 5b67e3835..7e23bb7e7 100644 --- a/integrations/amazon_sagemaker/tests/test_sagemaker.py +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -1,5 +1,5 @@ import os -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from botocore.exceptions import BotoCoreError @@ -83,14 +83,13 @@ def test_default_constructor(set_env_variables, mock_boto3_session): # noqa: AR def test_init_raises_boto_error(set_env_variables, mock_boto3_session): # noqa: ARG001 - with patch("boto3.Session") as mock_boto3_session: - mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises( - AWSConfigurationError, - match="Could not connect to SageMaker Inference Endpoint 'test-model'." - "Make sure the Endpoint exists and AWS environment is configured.", - ): - SagemakerGenerator(model="test-model") + mock_boto3_session.side_effect = BotoCoreError() + with pytest.raises( + AWSConfigurationError, + match="Could not connect to SageMaker Inference Endpoint 'test-model'." 
+ "Make sure the Endpoint exists and AWS environment is configured.", + ): + SagemakerGenerator(model="test-model") def test_run_with_list_of_dictionaries(set_env_variables, mock_boto3_session): # noqa: ARG001 diff --git a/integrations/anthropic/CHANGELOG.md b/integrations/anthropic/CHANGELOG.md new file mode 100644 index 000000000..450fe570a --- /dev/null +++ b/integrations/anthropic/CHANGELOG.md @@ -0,0 +1,28 @@ +# Changelog + +## [integrations/anthropic-v0.4.1] - 2024-07-17 + +### ⚙️ Miscellaneous Tasks + +- Update ruff invocation to include check parameter (#853) + +## [integrations/anthropic-v0.4.0] - 2024-06-21 + +### 🚀 Features + +- Update Anthropic/Cohere for tools use (#790) +- Update Anthropic default models, pydocs (#839) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) + +## [integrations/anthropic-v0.2.0] - 2024-03-15 + +## [integrations/anthropic-v0.1.0] - 2024-03-15 + +### 🚀 Features + +- Add AnthropicGenerator and AnthropicChatGenerator (#573) + + diff --git a/integrations/anthropic/example/documentation_rag_with_claude.py b/integrations/anthropic/example/documentation_rag_with_claude.py index 98cc9d40e..eb7ec2ad0 100644 --- a/integrations/anthropic/example/documentation_rag_with_claude.py +++ b/integrations/anthropic/example/documentation_rag_with_claude.py @@ -1,7 +1,7 @@ # To run this example, you will need to set a `ANTHROPIC_API_KEY` environment variable. from haystack import Pipeline -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.components.builders import ChatPromptBuilder from haystack.components.converters import HTMLToDocument from haystack.components.fetchers import LinkContentFetcher from haystack.components.generators.utils import print_streaming_chunk @@ -18,12 +18,11 @@ rag_pipeline = Pipeline() rag_pipeline.add_component("fetcher", LinkContentFetcher()) rag_pipeline.add_component("converter", HTMLToDocument()) -rag_pipeline.add_component("prompt_builder", DynamicChatPromptBuilder(runtime_variables=["documents"])) +rag_pipeline.add_component("prompt_builder", ChatPromptBuilder()) rag_pipeline.add_component( "llm", AnthropicChatGenerator( api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), - model="claude-3-sonnet-20240229", streaming_callback=print_streaming_chunk, ), ) @@ -31,12 +30,12 @@ rag_pipeline.connect("fetcher", "converter") rag_pipeline.connect("converter", "prompt_builder") -rag_pipeline.connect("prompt_builder", "llm") +rag_pipeline.connect("prompt_builder.prompt", "llm.messages") question = "What are the best practices in prompt engineering?" 
rag_pipeline.run( data={ "fetcher": {"urls": ["https://docs.anthropic.com/claude/docs/prompt-engineering"]}, - "prompt_builder": {"template_variables": {"query": question}, "prompt_source": messages}, + "prompt_builder": {"template_variables": {"query": question}, "template": messages}, } ) diff --git a/integrations/anthropic/pydoc/config.yml b/integrations/anthropic/pydoc/config.yml index 553dfcaef..9c1e39daf 100644 --- a/integrations/anthropic/pydoc/config.yml +++ b/integrations/anthropic/pydoc/config.yml @@ -15,12 +15,12 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Anthropic integration for Haystack category_slug: integrations-api title: Anthropic slug: integrations-anthropic - order: 22 + order: 23 markdown: descriptive_class_title: false descriptive_module_title: true diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index b04a54258..3f8c9812b 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -25,10 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "anthropic", -] +dependencies = ["haystack-ai", "anthropic"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/anthropic#readme" @@ -50,50 +45,31 @@ git_describe_command = 'git describe --tags --match="integrations/anthropic-v[0- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] -all = [ - "style", - "typing", -] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -133,11 +109,19 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - 
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", # Ignore unused params - "ARG001", "ARG002", "ARG005", + "ARG001", + "ARG002", + "ARG005", ] unfixable = [ # Don't touch unused imports @@ -161,12 +145,8 @@ parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -180,8 +160,5 @@ ignore_missing_imports = true [tool.pytest.ini_options] addopts = "--strict-markers" -markers = [ - "unit: unit tests", - "integration: integration tests", -] +markers = ["unit: unit tests", "integration: integration tests"] log_cli = true diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 6f43855b7..9954f08c5 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -1,4 +1,5 @@ import dataclasses +import json from typing import Any, Callable, ClassVar, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -7,13 +8,14 @@ from anthropic import Anthropic, Stream from anthropic.types import ( - ContentBlock, ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent, MessageStreamEvent, + TextBlock, TextDelta, + ToolUseBlock, ) logger = logging.getLogger(__name__) @@ -25,21 +27,21 @@ class AnthropicChatGenerator: Enables text generation using Anthropic state-of-the-art Claude 3 family of large language models (LLMs) through the Anthropic messaging API. - It supports models like `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku`, accessed through the - `/v1/messages` API endpoint using the Claude v2.1 messaging version. + It supports models like `claude-3-5-sonnet`, `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku`, + accessed through the [`/v1/messages`](https://docs.anthropic.com/en/api/messages) API endpoint. Users can pass any text generation parameters valid for the Anthropic messaging API directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs` parameter in the `run` method. For more details on the parameters supported by the Anthropic API, refer to the - Anthropic Message API [documentation](https://docs.anthropic.com/claude/reference/messages_post). + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). ```python from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator from haystack.dataclasses import ChatMessage messages = [ChatMessage.from_user("What's Natural Language Processing?")] - client = AnthropicChatGenerator(model="claude-3-sonnet-20240229") + client = AnthropicChatGenerator(model="claude-3-5-sonnet-20240620") response = client.run(messages) print(response) @@ -47,20 +49,23 @@ class AnthropicChatGenerator: >> focuses on enabling computers to understand, interpret, and generate human language. 
It involves developing >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and >> communicate in natural languages like English, Spanish, or Chinese.', role=, - >> name=None, meta={'model': 'claude-3-sonnet-20240229', 'index': 0, 'finish_reason': 'end_turn', + >> name=None, meta={'model': 'claude-3-5-sonnet-20240620', 'index': 0, 'finish_reason': 'end_turn', >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} ``` For more details on supported models and their capabilities, refer to the Anthropic [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). - Note: We don't yet support vision [capabilities](https://docs.anthropic.com/claude/docs/vision) in the current - implementation. + Note: We only support text input/output modalities, and + image [modality](https://docs.anthropic.com/en/docs/build-with-claude/vision) is not supported in + this version of AnthropicChatGenerator. """ # The parameters that can be passed to the Anthropic API https://docs.anthropic.com/claude/reference/messages_post ALLOWED_PARAMS: ClassVar[List[str]] = [ "system", + "tools", + "tool_choice", "max_tokens", "metadata", "stop_sequences", @@ -72,9 +77,10 @@ class AnthropicChatGenerator: def __init__( self, api_key: Secret = Secret.from_env_var("ANTHROPIC_API_KEY"), # noqa: B008 - model: str = "claude-3-sonnet-20240229", + model: str = "claude-3-5-sonnet-20240620", streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, ): """ Creates an instance of AnthropicChatGenerator. @@ -95,13 +101,18 @@ def __init__( - `temperature`: The temperature to use for sampling. - `top_p`: The top_p value to use for nucleus sampling. - `top_k`: The top_k value to use for top-k sampling. - + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. 
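        For illustration, a minimal, hypothetical sketch of how this parameter interacts with tool use
        (the `get_weather` tool and its schema are invented here, written in the same name/description/input_schema
        format used by the `test_tools_use` integration test further down; an `ANTHROPIC_API_KEY` environment
        variable is assumed):

        ```python
        from haystack.dataclasses import ChatMessage
        from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

        # Hypothetical tool, declared in the Anthropic tools format (name / description / input_schema).
        get_weather = {
            "name": "get_weather",
            "description": "Returns the current weather for a given city.",
            "input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "Name of the city, e.g. Paris"}},
                "required": ["city"],
            },
        }

        # With ignore_tools_thinking_messages=True (the default) only the tool-use reply is kept;
        # set it to False to also receive Claude's preceding "thinking" message.
        client = AnthropicChatGenerator(ignore_tools_thinking_messages=True)
        result = client.run(
            messages=[ChatMessage.from_user("What is the weather in Paris?")],
            generation_kwargs={"tools": [get_weather]},
        )
        print(result["replies"][0].content)  # JSON string containing the tool "name" and "input"
        ```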
""" self.api_key = api_key self.model = model self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback self.client = Anthropic(api_key=self.api_key.resolve_value()) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -123,6 +134,7 @@ def to_dict(self) -> Dict[str, Any]: streaming_callback=callback_name, generation_kwargs=self.generation_kwargs, api_key=self.api_key.to_dict(), + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, ) @classmethod @@ -201,20 +213,33 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # capture stop reason and stop sequence delta = stream_event completions = [self._connect_chunks(chunks, start_event, delta)] + # if streaming is disabled, the response is an Anthropic Message elif isinstance(response, Message): + has_tools_msgs = any(isinstance(content_block, ToolUseBlock) for content_block in response.content) + if has_tools_msgs and self.ignore_tools_thinking_messages: + response.content = [block for block in response.content if isinstance(block, ToolUseBlock)] completions = [self._build_message(content_block, response) for content_block in response.content] + # rename the meta key to be inline with OpenAI meta output keys + for response in completions: + if response.meta is not None and "usage" in response.meta: + response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") + response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") + return {"replies": completions} - def _build_message(self, content_block: ContentBlock, message: Message) -> ChatMessage: + def _build_message(self, content_block: Union[TextBlock, ToolUseBlock], message: Message) -> ChatMessage: """ Converts the non-streaming Anthropic Message to a ChatMessage. :param content_block: The content block of the message. :param message: The non-streaming Anthropic Message. :returns: The ChatMessage. """ - chat_message = ChatMessage.from_assistant(content_block.text) + if isinstance(content_block, TextBlock): + chat_message = ChatMessage.from_assistant(content_block.text) + else: + chat_message = ChatMessage.from_assistant(json.dumps(content_block.model_dump(mode="json"))) chat_message.meta.update( { "model": message.model, diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py index aa78dfed1..4cb8fd3e6 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py @@ -29,7 +29,7 @@ class AnthropicGenerator: ```python from haystack_integrations.components.generators.anthropic import AnthropicGenerator - client = AnthropicGenerator(model="claude-2.1") + client = AnthropicGenerator(model="claude-3-sonnet-20240229") response = client.run("What's Natural Language Processing? 
Be brief.") print(response) >>{'replies': ['Natural language processing (NLP) is a branch of artificial intelligence focused on enabling diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 41cc3eb5d..3ffa24c94 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -1,3 +1,4 @@ +import json import os import anthropic @@ -22,9 +23,10 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") component = AnthropicChatGenerator() assert component.client.api_key == "test-api-key" - assert component.model == "claude-3-sonnet-20240229" + assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is None assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) @@ -34,14 +36,16 @@ def test_init_fail_wo_api_key(self, monkeypatch): def test_init_with_parameters(self): component = AnthropicChatGenerator( api_key=Secret.from_token("test-api-key"), - model="claude-3-sonnet-20240229", + model="claude-3-5-sonnet-20240620", streaming_callback=print_streaming_chunk, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, ) assert component.client.api_key == "test-api-key" - assert component.model == "claude-3-sonnet-20240229" + assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") @@ -51,9 +55,10 @@ def test_to_dict_default(self, monkeypatch): "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-20240620", "streaming_callback": None, "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, }, } @@ -69,16 +74,17 @@ def test_to_dict_with_parameters(self, monkeypatch): "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") component = AnthropicChatGenerator( - model="claude-3-sonnet-20240229", + model="claude-3-5-sonnet-20240620", streaming_callback=lambda x: x, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) @@ -87,9 +93,10 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": 
"env_var"}, - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-20240620", "streaming_callback": "tests.test_chat_generator.", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } @@ -99,13 +106,14 @@ def test_from_dict(self, monkeypatch): "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } component = AnthropicChatGenerator.from_dict(data) - assert component.model == "claude-3-sonnet-20240229" + assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY") @@ -116,9 +124,10 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, - "model": "claude-3-sonnet-20240229", + "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, }, } with pytest.raises(ValueError, match="None of the .* environment variables are set"): @@ -216,3 +225,40 @@ def streaming_callback(chunk: StreamingChunk): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_tools_use(self): + # See https://docs.anthropic.com/en/docs/tool-use for more information + tools_schema = { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given ticker symbol.", + "input_schema": { + "type": "object", + "properties": { + "ticker": {"type": "string", "description": "The stock ticker symbol, e.g. 
AAPL for Apple Inc."} + }, + "required": ["ticker"], + }, + } + client = AnthropicChatGenerator() + response = client.run( + messages=[ChatMessage.from_user("What is the current price of AAPL?")], + generation_kwargs={"tools": [tools_schema]}, + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert first_reply.meta, "First reply has no metadata" + fc_response = json.loads(first_reply.content) + assert "name" in fc_response, "First reply does not contain name of the tool" + assert "input" in fc_response, "First reply does not contain input of the tool" diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md new file mode 100644 index 000000000..55c22f540 --- /dev/null +++ b/integrations/astra/CHANGELOG.md @@ -0,0 +1,91 @@ +# Changelog + +## [integrations/astra-v0.9.2] - 2024-07-22 + +## [integrations/astra-v0.9.1] - 2024-07-15 + +### 🚀 Features + +- Defer the database connection to when it's needed (#769) +- Add filter_policy to astra integration (#827) + +### 🐛 Bug Fixes + +- Fix astra nightly + +- Fix typing checks + +- `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) + +## [integrations/astra-v0.7.0] - 2024-05-15 + +### 🐛 Bug Fixes + +- Make unit tests pass (#720) + +## [integrations/astra-v0.6.0] - 2024-04-24 + +### 🐛 Bug Fixes + +- Pass namespace in the docstore init (#683) + +## [integrations/astra-v0.5.1] - 2024-04-09 + +### 🐛 Bug Fixes + +- Fix haystack-ai pin (#649) + + + +## [integrations/astra-v0.5.0] - 2024-03-18 + +### 📚 Documentation + +- Review `integrations.astra` (#498) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +## [integrations/astra-v0.4.2] - 2024-02-21 + +### FIX + +- Proper name for the sort param (#454) + +## [integrations/astra-v0.4.1] - 2024-02-20 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Fix integration tests (#450) + + + +## [integrations/astra-v0.4.0] - 2024-02-20 + +### 📚 Documentation + +- Update category slug (#442) + +## [integrations/astra-v0.3.0] - 2024-02-15 + +## [integrations/astra-v0.2.0] - 2024-02-13 + +### Astra + +- Generate api docs (#327) + +### Refact + +- [**breaking**] Change import paths (#277) + +## [integrations/astra-v0.1.1] - 2024-01-18 + +## [integrations/astra-v0.1.0] - 2024-01-11 + + diff --git a/integrations/astra/README.md b/integrations/astra/README.md index f8b6f7c31..f679b7207 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -24,8 +24,8 @@ pyenv local 3.9 Local install for the package `pip install -e .` To execute integration tests, add needed environment variables -`ASTRA_DB_API_ENDPOINT=` -`ASTRA_DB_APPLICATION_TOKEN=` +`ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com"`, +`ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."` and execute `python examples/example.py` @@ -34,10 +34,10 @@ Install requirements Export environment variables ``` -export ASTRA_DB_API_ENDPOINT= -export 
ASTRA_DB_APPLICATION_TOKEN= -export COLLECTION_NAME= -export OPENAI_API_KEY= +export ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com" +export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." +export COLLECTION_NAME="my_collection" +export OPENAI_API_KEY="sk-..." ``` run the python examples @@ -59,19 +59,17 @@ from haystack.document_stores.types.policy import DuplicatePolicy Load in environment variables: ``` -api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT", "") -token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") -collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") +namespace = os.environ.get("ASTRA_DB_KEYSPACE") +collection_name = os.environ.get("COLLECTION_NAME", "haystack_vector_search") ``` -Create the Document Store object: +Create the Document Store object (API Endpoint and Token are read off the environment): ``` document_store = AstraDocumentStore( - api_endpoint=api_endpoint, - token=token, collection_name=collection_name, + namespace=namespace, duplicates_policy=DuplicatePolicy.SKIP, - embedding_dim=384, + embedding_dimension=384, ) ``` @@ -92,3 +90,31 @@ Add your AstraEmbeddingRetriever into the pipeline Add other components and connect them as desired. Then run your pipeline: `pipeline.run(...)` + +## Warnings about indexing + +When creating an Astra DB document store, you may see a warning similar to the following: + +> Astra DB collection '...' is detected as having indexing turned on for all fields (either created manually or by older versions of this plugin). This implies stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. + +or, + +> Astra DB collection '...' is detected as having the following indexing policy: {...}. This does not match the requested indexing policy for this object: {...}. In particular, there may be stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. + + +The reason for the warning is that the requested collection already exists on the database, and it is configured to [index all of its fields for search](https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option), possibly implicitly, by default. When the Haystack object tries to create it, it attempts to enforce, instead, an indexing policy tailored to the prospected usage: this is both to enable storing very long texts and to avoid indexing fields that will never be used in filtering a search (indexing those would also have a slight performance cost for writes). + +Typically there are two reasons why you may encounter the warning: + +1. you have created a collection by other means than letting this component do it for you: for example, through the Astra UI, or using AstraPy's `create_collection` method of class `Database` directly; +2. you have created the collection with an older version of the plugin. + +Keep in mind that this is a warning and your application will continue running just fine, as long as you don't store very long texts. +However, should you need to add to the document store, for example, a document with a very long textual content, you will get an indexing error from the database. 
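+For reference, a minimal sketch of the "fresh collection" route discussed in the remediation section below (the collection name is only an example; the Astra credentials are assumed to be set as environment variables, as described above). When the named collection does not exist yet, the document store creates it with its tailored indexing policy and no warning is emitted:

```python
import os

from haystack_integrations.document_stores.astra import AstraDocumentStore

# ASTRA_DB_API_ENDPOINT and ASTRA_DB_APPLICATION_TOKEN are read from the environment.
document_store = AstraDocumentStore(
    collection_name="haystack_vector_search_v2",  # a brand-new collection, created on first use
    namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    embedding_dimension=384,
)
```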
+ +### Remediation + +You have several options: + +- you can ignore the warning because you know your application will never need to store very long textual contents; +- if you can afford populating the collection anew, you can drop it and re-run the Haystack application: the collection will be created with the optimized indexing settings. **This is the recommended option, when possible**. diff --git a/integrations/astra/pydoc/config.yml b/integrations/astra/pydoc/config.yml index 61fec0523..ed35427e6 100644 --- a/integrations/astra/pydoc/config.yml +++ b/integrations/astra/pydoc/config.yml @@ -16,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Astra integration for Haystack category_slug: integrations-api title: Astra diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 3a36df238..7d543ddc9 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "Anant Corporation", email = "support@anant.us" }, -] +authors = [{ name = "Anant Corporation", email = "support@anant.us" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,12 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "pydantic", - "typing_extensions", - "astrapy", -] +dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra#readme" @@ -51,47 +44,28 @@ git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*" dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -104,7 +78,7 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -select = [ +lint.select = [ "A", "ARG", "B", @@ -131,29 +105,35 @@ select = [ "W", "YTT", ] 
-ignore = [ +lint.ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] -unfixable = [ +lint.unfixable = [ # Don't touch unused imports "F401", ] -exclude = ["example"] +lint.exclude = ["example"] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] @@ -165,20 +145,13 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" -markers = [ - "unit: unit tests", - "integration: integration tests" -] +markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] module = [ @@ -187,6 +160,7 @@ module = [ "pydantic.*", "haystack.*", "haystack_integrations.*", - "pytest.*" + "pytest.*", + "openpyxl.*", ] ignore_missing_imports = true diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index 80e436e0a..cfa45e81f 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -2,9 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.astra import AstraDocumentStore @@ -31,14 +33,25 @@ class AstraEmbeddingRetriever: ``` """ - def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): + def __init__( + self, + document_store: AstraDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): """ + :param document_store: An instance of AstraDocumentStore. :param filters: a dictionary with filters to narrow down the search space. :param top_k: the maximum number of documents to retrieve. + :param filter_policy: Policy to determine how filters are applied. 
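        A minimal sketch of how the policy is typically set (the `meta.lang` filter is illustrative;
        the Astra credentials are assumed to be available in the environment):

        ```python
        from haystack.document_stores.types import FilterPolicy
        from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever
        from haystack_integrations.document_stores.astra import AstraDocumentStore

        document_store = AstraDocumentStore()  # endpoint and token read from the environment

        # REPLACE (the default): filters passed to run() override the ones given here.
        # MERGE: filters passed to run() are combined with these init-time filters.
        retriever = AstraEmbeddingRetriever(
            document_store=document_store,
            filters={"field": "meta.lang", "operator": "==", "value": "en"},
            top_k=10,
            filter_policy=FilterPolicy.MERGE,
        )
        ```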
""" - self.filters = filters + self.filters = filters or {} self.top_k = top_k self.document_store = document_store + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) if not isinstance(document_store, AstraDocumentStore): message = "document_store must be an instance of AstraDocumentStore" @@ -49,17 +62,15 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = """Retrieve documents from the AstraDocumentStore. :param query_embedding: floats representing the query embedding - :param filters: filters to narrow down the search space. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: the maximum number of documents to retrieve. :returns: a dictionary with the following keys: - `documents`: A list of documents retrieved from the AstraDocumentStore. """ - - if not top_k: - top_k = self.top_k - - if not filters: - filters = self.filters + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)} @@ -74,6 +85,7 @@ def to_dict(self) -> Dict[str, Any]: self, filters=self.filters, top_k=self.top_k, + filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -89,4 +101,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever": """ document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index fc5f4b6c9..5a88a0fe9 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -74,12 +74,13 @@ def __init__( caller_version=integration_version, ) + indexing_options = {"indexing": {"deny": NON_INDEXED_FIELDS}} try: # Create and connect to the newly created collection self._astra_db_collection = self._astra_db.create_collection( collection_name=collection_name, dimension=embedding_dimension, - options={"indexing": {"deny": NON_INDEXED_FIELDS}}, + options=indexing_options, ) except APIRequestError: # possibly the collection is preexisting and has legacy @@ -98,11 +99,16 @@ def __init__( if "indexing" not in pre_col_options: warn( ( - f"Collection '{collection_name}' is detected as legacy" - " and has indexing turned on for all fields. This" - " implies stricter limitations on the amount of text" - " each entry can store. Consider reindexing anew on a" - " fresh collection to be able to store longer texts." + f"Astra DB collection '{collection_name}' is " + "detected as having indexing turned on for all " + "fields (either created manually or by older " + "versions of this plugin). 
This implies stricter " + "limitations on the amount of text each string in a " + "document can store. Consider indexing anew on a " + "fresh collection to be able to store longer texts. " + "See https://github.com/deepset-ai/haystack-core-" + "integrations/blob/main/integrations/astra/README" + ".md#warnings-about-indexing for more details." ), UserWarning, stacklevel=2, @@ -110,16 +116,22 @@ def __init__( self._astra_db_collection = self._astra_db.collection( collection_name=collection_name, ) - else: - options_json = json.dumps(pre_col_options["indexing"]) + elif pre_col_options["indexing"] != indexing_options["indexing"]: + detected_options_json = json.dumps(pre_col_options["indexing"]) + indexing_options_json = json.dumps(indexing_options["indexing"]) warn( ( - f"Collection '{collection_name}' has unexpected 'indexing'" - f" settings (options.indexing = {options_json})." - " This can result in odd behaviour when running " - " metadata filtering and/or unwarranted limitations" - " on storing long texts. Consider reindexing anew on a" - " fresh collection." + f"Astra DB collection '{collection_name}' is " + "detected as having the following indexing policy: " + f"{detected_options_json}. This does not match the requested " + f"indexing policy for this object: {indexing_options_json}. " + "In particular, there may be stricter " + "limitations on the amount of text each string in a " + "document can store. Consider indexing anew on a " + "fresh collection to be able to store longer texts. " + "See https://github.com/deepset-ai/haystack-core-" + "integrations/blob/main/integrations/astra/README" + ".md#warnings-about-indexing for more details." ), UserWarning, stacklevel=2, @@ -127,6 +139,9 @@ def __init__( self._astra_db_collection = self._astra_db.collection( collection_name=collection_name, ) + else: + # the collection mismatch lies elsewhere than the indexing + raise else: # other exception raise @@ -208,6 +223,7 @@ def find_documents(self, find_query): filter=find_query.get("filter"), sort=find_query.get("sort"), options=find_query.get("options"), + projection={"*": 1}, ) if "data" in response_dict and "documents" in response_dict["data"]: @@ -273,6 +289,7 @@ def update_document(self, document: Dict, id_key: str): filter={id_key: document_id}, update={"$set": document}, options={"returnDocument": "after"}, + projection={"*": 1}, ) document[id_key] = document_id diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index fdbf95eb0..1dea6e08b 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -85,6 +85,7 @@ def __init__( "Set the ASTRA_DB_API_ENDPOINT environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) + self.resolved_api_endpoint = resolved_api_endpoint resolved_token = token.resolve_value() if resolved_token is None: @@ -93,6 +94,7 @@ def __init__( "Set the ASTRA_DB_APPLICATION_TOKEN environment variable (recommended) or pass it explicitly." 
) raise ValueError(msg) + self.resolved_token = resolved_token self.api_endpoint = api_endpoint self.token = token @@ -101,15 +103,20 @@ def __init__( self.duplicates_policy = duplicates_policy self.similarity = similarity self.namespace = namespace - - self.index = AstraClient( - resolved_api_endpoint, - resolved_token, - self.collection_name, - self.embedding_dimension, - self.similarity, - namespace, - ) + self._index: Optional[AstraClient] = None + + @property + def index(self) -> AstraClient: + if self._index is None: + self._index = AstraClient( + self.resolved_api_endpoint, + self.resolved_token, + self.collection_name, + self.embedding_dimension, + self.similarity, + self.namespace, + ) + return self._index @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore": diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py index 44cac25e6..61f3e5402 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py @@ -52,7 +52,7 @@ def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[ # TODO consider other operators, or filters that are not with the same structure as field operator value OPERATORS = { "==": "$eq", - "!=": "$neq", + "!=": "$ne", ">": "$gt", ">=": "$gte", "<": "$lt", @@ -73,7 +73,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) operator = condition["operator"] - conditions = [_parse_comparison_condition(c) for c in condition["conditions"]] + conditions = [_normalize_filters(c) for c in condition["conditions"]] if len(conditions) > 1: conditions = _normalize_ranges(conditions) if operator not in OPERATORS: diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index 3650ffd61..df181ad8c 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -14,18 +14,30 @@ from haystack_integrations.document_stores.astra import AstraDocumentStore -def test_namespace_init(): +@pytest.fixture +def mock_auth(monkeypatch): + monkeypatch.setenv("ASTRA_DB_API_ENDPOINT", "http://example.com") + monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token") + + +@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") +def test_init_is_lazy(_mock_client, mock_auth): # noqa + _ = AstraDocumentStore() + _mock_client.assert_not_called() + + +def test_namespace_init(mock_auth): # noqa with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: - AstraDocumentStore() + _ = AstraDocumentStore().index assert "namespace" in client.call_args.kwargs assert client.call_args.kwargs["namespace"] is None - AstraDocumentStore(namespace="foo") + _ = AstraDocumentStore(namespace="foo").index assert "namespace" in client.call_args.kwargs assert client.call_args.kwargs["namespace"] == "foo" -def test_to_dict(): +def test_to_dict(mock_auth): # noqa with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): ds = AstraDocumentStore() result = ds.to_dict() @@ -160,6 +172,34 @@ def test_delete_documents_more_than_twenty_delete_ids(self, document_store: Astr # No Document has been deleted assert document_store.count_documents() == 0 + def test_filter_documents_nested_filters(self, document_store, 
filterable_docs): + filter_criteria = { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "100"}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.chapter", "operator": "==", "value": "abstract"}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + ], + } + + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters=filter_criteria) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("page") == "100" + and (d.meta.get("chapter") == "abstract" or d.meta.get("chapter") == "intro") + ], + ) + @pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index b52cedf33..4ffe30919 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -3,10 +3,36 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import patch +import pytest +from haystack.document_stores.types import FilterPolicy + from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever from haystack_integrations.document_stores.astra import AstraDocumentStore +@patch.dict( + "os.environ", + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, +) +@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") +def test_retriever_init(*_): + ds = AstraDocumentStore() + retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace") + assert retriever.filters == {"foo": "bar"} + assert retriever.top_k == 99 + assert retriever.document_store == ds + assert retriever.filter_policy == FilterPolicy.REPLACE + + retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=FilterPolicy.MERGE) + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") + + with pytest.raises(ValueError): + AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=None) + + @patch.dict( "os.environ", {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, @@ -21,6 +47,7 @@ def test_retriever_to_json(*_): "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, + "filter_policy": "replace", "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { @@ -43,6 +70,36 @@ def test_retriever_to_json(*_): ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json(*_): + data = { + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", + "init_parameters": { + "filters": {"bar": "baz"}, + "top_k": 42, + "filter_policy": "replace", + "document_store": { + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", + "init_parameters": { + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, + "collection_name": "documents", + "embedding_dimension": 768, + "duplicates_policy": "NONE", + "similarity": "cosine", + }, 
+ }, + }, + } + retriever = AstraEmbeddingRetriever.from_dict(data) + assert retriever.top_k == 42 + assert retriever.filters == {"bar": "baz"} + + +@patch.dict( + "os.environ", + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, +) +@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") +def test_retriever_from_json_no_filter_policy(*_): data = { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { @@ -64,3 +121,4 @@ def test_retriever_from_json(*_): retriever = AstraEmbeddingRetriever.from_dict(data) assert retriever.top_k == 42 assert retriever.filters == {"bar": "baz"} + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md new file mode 100644 index 000000000..f6a23d84a --- /dev/null +++ b/integrations/chroma/CHANGELOG.md @@ -0,0 +1,100 @@ +# Changelog + +## [integrations/chroma-v0.21.1] - 2024-07-17 + +### 🐛 Bug Fixes + +- `ChromaDocumentStore` - discard `meta` items when the type of their value is not supported in Chroma (#907) + +## [integrations/chroma-v0.21.0] - 2024-07-16 + +### 🚀 Features + +- Add metadata parameter to ChromaDocumentStore. (#906) + +## [integrations/chroma-v0.20.1] - 2024-07-15 + +### 🚀 Features + +- Added distance_function property to ChromadocumentStore (#817) +- Add filter_policy to chroma integration (#826) + +### 🐛 Bug Fixes + +- Allow search in ChromaDocumentStore without metadata (#863) +- `Chroma` - Fallback to default filter policy when deserializing retrievers without the init parameter (#897) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/chroma-v0.18.0] - 2024-05-31 + +## [integrations/chroma-v0.17.0] - 2024-05-10 + +## [integrations/chroma-v0.16.0] - 2024-05-02 + +### 📚 Documentation + +- Small consistency improvements (#536) +- Disable-class-def (#556) + +## [integrations/chroma-v0.15.0] - 2024-03-01 + +## [integrations/chroma-v0.14.0] - 2024-02-29 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Serialize the path to the local db (#506) + +### 📚 Documentation + +- Update category slug (#442) +- Review chroma integration (#501) + +## [integrations/chroma-v0.13.0] - 2024-02-13 + +## [integrations/chroma-v0.12.0] - 2024-02-06 + +### 🚀 Features + +- Generate API docs (#262) + +## [integrations/chroma-v0.11.0] - 2024-01-18 + +### 🐛 Bug Fixes + +- Chroma DocumentStore creation for pre-existing collection name (#157) + +## [integrations/chroma-v0.9.0] - 2023-12-20 + +### 🐛 Bug Fixes + +- Fix project urls (#96) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +## [integrations/chroma-v0.8.1] - 2023-12-05 + +### 🐛 Bug Fixes + +- Fix import and increase version (#77) + + + +## [integrations/chroma-v0.8.0] - 2023-12-04 + +### 🐛 Bug Fixes + +- Fix license headers + + + diff --git a/integrations/chroma/pydoc/config.yml b/integrations/chroma/pydoc/config.yml index c28902080..1e678b4cc 100644 --- a/integrations/chroma/pydoc/config.yml +++ b/integrations/chroma/pydoc/config.yml @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Chroma integration for Haystack 
category_slug: integrations-api title: Chroma diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 96ee2f9f8..b4591e041 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "John Doe", email = "jd@example.com" }, -] +authors = [{ name = "John Doe", email = "jd@example.com" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,11 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "chromadb<0.4.20", # FIXME: investigate why filtering tests broke on 0.4.20 - "typing_extensions>=4.8.0" -] +dependencies = ["haystack-ai", "chromadb>=0.5.0", "typing_extensions>=4.8.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" @@ -50,23 +44,18 @@ git_describe_command = 'git describe --tags --match="integrations/chroma-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", "databind-core<4.5.0", # FIXME: the latest 4.5.0 causes loops in pip resolver ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.9", "3.10"] @@ -77,23 +66,13 @@ dependencies = [ "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", - "numpy", # we need the stubs from the main package + "numpy", # we need the stubs from the main package ] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -139,9 +118,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", # Ignore unused params "ARG002", ] @@ -169,20 +154,13 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" -markers = [ - "unit: unit tests", - "integration: integration tests" -] +markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] module = [ @@ -190,6 +168,6 @@ module = [ "haystack.*", "haystack_integrations.*", "pytest.*", - "numpy.*" + "numpy.*", ] ignore_missing_imports = true diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 7138eff88..71ac3457e 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -1,9 +1,11 @@ # SPDX-FileCopyrightText: 2023-present John Doe # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.chroma import ChromaDocumentStore @@ -41,27 +43,40 @@ class ChromaQueryTextRetriever: ``` """ - def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): + def __init__( + self, + document_store: ChromaDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): """ :param document_store: an instance of `ChromaDocumentStore`. :param filters: filters to narrow down the search space. :param top_k: the maximum number of documents to retrieve. + :param filter_policy: Policy to determine how filters are applied. """ - self.filters = filters + self.filters = filters or {} self.top_k = top_k self.document_store = document_store + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) @component.output_types(documents=List[Document]) def run( self, query: str, - _: Optional[Dict[str, Any]] = None, # filters not yet supported + filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ): """ Run the retriever on the given input data. :param query: The input data for the retriever. In this case, a plain-text query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: The maximum number of documents to retrieve. If not specified, the default value from the constructor is used. 
:returns: A dictionary with the following keys: @@ -69,9 +84,9 @@ def run( :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k - - return {"documents": self.document_store.search([query], top_k)[0]} + return {"documents": self.document_store.search([query], top_k, filters)[0]} @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": @@ -85,6 +100,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": """ document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: @@ -98,6 +118,7 @@ def to_dict(self) -> Dict[str, Any]: self, filters=self.filters, top_k=self.top_k, + filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -119,9 +140,17 @@ def run( Run the retriever on the given input data. :param query_embedding: the query embeddings. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + If not specified, the default value from the constructor is used. + :returns: a dictionary with the following keys: - `documents`: List of documents returned by the search engine. """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k query_embeddings = [query_embedding] diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 6d795f8ca..3ea84780f 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import chromadb import numpy as np @@ -18,6 +18,10 @@ logger = logging.getLogger(__name__) +VALID_DISTANCE_FUNCTIONS = "l2", "cosine", "ip" +SUPPORTED_TYPES_FOR_METADATA_VALUES = str, int, float, bool + + class ChromaDocumentStore: """ A document store using [Chroma](https://docs.trychroma.com/) as the backend. @@ -31,6 +35,8 @@ def __init__( collection_name: str = "documents", embedding_function: str = "default", persist_path: Optional[str] = None, + distance_function: Literal["l2", "cosine", "ip"] = "l2", + metadata: Optional[dict] = None, **embedding_function_params, ): """ @@ -45,22 +51,59 @@ def __init__( :param collection_name: the name of the collection to use in the database. :param embedding_function: the name of the embedding function to use to embed the query :param persist_path: where to store the database. If None, the database will be `in-memory`. 
+ :param distance_function: The distance metric for the embedding space. + - `"l2"` computes the Euclidean (straight-line) distance between vectors, + where smaller scores indicate more similarity. + - `"cosine"` computes the cosine similarity between vectors, + with higher scores indicating greater similarity. + - `"ip"` stands for inner product, where higher scores indicate greater similarity between vectors. + **Note**: `distance_function` can only be set during the creation of a collection. + To change the distance metric of an existing collection, consider cloning the collection. + :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client + method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the + `distance_function` parameter above. + :param embedding_function_params: additional parameters to pass to the embedding function. """ + + if distance_function not in VALID_DISTANCE_FUNCTIONS: + error_message = ( + f"Invalid distance_function: '{distance_function}' for the collection. " + f"Valid options are: {VALID_DISTANCE_FUNCTIONS}." + ) + raise ValueError(error_message) + # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function self._embedding_function_params = embedding_function_params self._persist_path = persist_path + self._distance_function = distance_function # Create the client instance if persist_path is None: self._chroma_client = chromadb.Client() else: self._chroma_client = chromadb.PersistentClient(path=persist_path) - self._collection = self._chroma_client.get_or_create_collection( - name=collection_name, - embedding_function=get_embedding_function(embedding_function, **embedding_function_params), - ) + + embedding_func = get_embedding_function(embedding_function, **embedding_function_params) + + metadata = metadata or {} + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = distance_function + + if collection_name in [c.name for c in self._chroma_client.list_collections()]: + self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) + + if metadata != self._collection.metadata: + logger.warning( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = self._chroma_client.create_collection( + name=collection_name, + metadata=metadata, + embedding_function=embedding_func, + ) def count_documents(self) -> int: """ @@ -184,7 +227,26 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D data = {"ids": [doc.id], "documents": [doc.content]} if doc.meta: - data["metadatas"] = [doc.meta] + valid_meta = {} + discarded_keys = [] + + for k, v in doc.meta.items(): + if isinstance(v, SUPPORTED_TYPES_FOR_METADATA_VALUES): + valid_meta[k] = v + else: + discarded_keys.append(k) + + if discarded_keys: + logger.warning( + "Document %s contains `meta` values of unsupported types for the keys: %s. " + "These items will be discarded. 
Supported types are: %s.", + doc.id, + ", ".join(discarded_keys), + ", ".join([t.__name__ for t in SUPPORTED_TYPES_FOR_METADATA_VALUES]), + ) + + if valid_meta: + data["metadatas"] = [valid_meta] if doc.embedding is not None: data["embeddings"] = [doc.embedding] @@ -209,16 +271,30 @@ def delete_documents(self, document_ids: List[str]) -> None: """ self._collection.delete(ids=document_ids) - def search(self, queries: List[str], top_k: int) -> List[List[Document]]: + def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: """Search the documents in the store using the provided text queries. :param queries: the list of queries to search for. :param top_k: top_k documents to return for each query. + :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ - results = self._collection.query( - query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"] - ) + if filters is None: + results = self._collection.query( + query_texts=queries, + n_results=top_k, + include=["embeddings", "documents", "metadatas", "distances"], + ) + else: + chroma_filters = self._normalize_filters(filters=filters) + results = self._collection.query( + query_texts=queries, + n_results=top_k, + where=chroma_filters[1], + where_document=chroma_filters[2], + include=["embeddings", "documents", "metadatas", "distances"], + ) + return self._query_result_to_documents(results) def search_embeddings( @@ -276,6 +352,7 @@ def to_dict(self) -> Dict[str, Any]: collection_name=self._collection_name, embedding_function=self._embedding_function, persist_path=self._persist_path, + distance_function=self._distance_function, **self._embedding_function_params, ) @@ -380,8 +457,12 @@ def _query_result_to_documents(result: QueryResult) -> List[List[Document]]: } # prepare metadata - if metadatas := result.get("metadatas"): - document_dict["meta"] = dict(metadatas[i][j]) + metadatas = result.get("metadatas") + try: + if metadatas and metadatas[i][j] is not None: + document_dict["meta"] = metadatas[i][j] + except IndexError: + pass if embeddings := result.get("embeddings"): document_dict["embedding"] = np.array(embeddings[i][j]) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py index 08d6db618..5c31070ad 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py @@ -9,6 +9,7 @@ GoogleVertexEmbeddingFunction, HuggingFaceEmbeddingFunction, InstructorEmbeddingFunction, + OllamaEmbeddingFunction, ONNXMiniLM_L6_V2, OpenAIEmbeddingFunction, SentenceTransformerEmbeddingFunction, @@ -25,6 +26,7 @@ "GoogleVertexEmbeddingFunction": GoogleVertexEmbeddingFunction, "HuggingFaceEmbeddingFunction": HuggingFaceEmbeddingFunction, "InstructorEmbeddingFunction": InstructorEmbeddingFunction, + "OllamaEmbeddingFunction": OllamaEmbeddingFunction, "ONNXMiniLM_L6_V2": ONNXMiniLM_L6_V2, "OpenAIEmbeddingFunction": OpenAIEmbeddingFunction, "Text2VecEmbeddingFunction": Text2VecEmbeddingFunction, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 5b827a984..b05c9ccfc 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ 
-1,6 +1,8 @@ # SPDX-FileCopyrightText: 2023-present John Doe # # SPDX-License-Identifier: Apache-2.0 +import logging +import operator import uuid from typing import List from unittest import mock @@ -56,6 +58,9 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do This can happen for example when the Document Store sets a score to returned Documents. Since we can't know what the score will be, we can't compare the Documents reliably. """ + received.sort(key=operator.attrgetter("id")) + expected.sort(key=operator.attrgetter("id")) + for doc_received, doc_expected in zip(received, expected): assert doc_received.content == doc_expected.content assert doc_received.meta == doc_expected.meta @@ -87,6 +92,43 @@ def test_delete_not_empty_nonexisting(self, document_store: ChromaDocumentStore) assert document_store.filter_documents(filters={"id": doc.id}) == [doc] + def test_search(self): + document_store = ChromaDocumentStore() + documents = [ + Document(content="First document", meta={"author": "Author1"}), + Document(content="Second document"), # No metadata + Document(content="Third document", meta={"author": "Author2"}), + Document(content="Fourth document"), # No metadata + ] + document_store.write_documents(documents) + result = document_store.search(["Third"], top_k=1) + + # Assertions to verify correctness + assert len(result) == 1 + assert result[0][0].content == "Third document" + + def test_write_documents_unsupported_meta_values(self, document_store: ChromaDocumentStore): + """ + Unsupported meta values should be removed from the documents before writing them to the database + """ + + docs = [ + Document(content="test doc 1", meta={"invalid": {"dict": "value"}}), + Document(content="test doc 2", meta={"invalid": ["list", "value"]}), + Document(content="test doc 3", meta={"ok": 123}), + ] + + document_store.write_documents(docs) + + written_docs = document_store.filter_documents() + written_docs.sort(key=lambda x: x.content) + + assert len(written_docs) == 3 + assert [doc.id for doc in written_docs] == [doc.id for doc in docs] + assert written_docs[0].meta == {} + assert written_docs[1].meta == {} + assert written_docs[2].meta == {"ok": 123} + @pytest.mark.integration def test_to_json(self, request): ds = ChromaDocumentStore( @@ -100,6 +142,7 @@ def test_to_json(self, request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, } @@ -114,6 +157,7 @@ def test_from_json(self): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, } @@ -124,12 +168,65 @@ def test_from_json(self): @pytest.mark.integration def test_same_collection_name_reinitialization(self): - ChromaDocumentStore("test_name") - ChromaDocumentStore("test_name") + ChromaDocumentStore("test_1") + ChromaDocumentStore("test_1") - @pytest.mark.skip(reason="Filter on array contents is not supported.") - def test_filter_document_array(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass + @pytest.mark.integration + def test_distance_metric_initialization(self): + store = ChromaDocumentStore("test_2", distance_function="cosine") + assert store._collection.metadata["hnsw:space"] == "cosine" + + with pytest.raises(ValueError): + ChromaDocumentStore("test_3", distance_function="jaccard") + + @pytest.mark.integration + def test_distance_metric_reinitialization(self, caplog): + store = ChromaDocumentStore("test_4", 
distance_function="cosine") + + with caplog.at_level(logging.WARNING): + new_store = ChromaDocumentStore("test_4", distance_function="ip") + + assert ( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + in caplog.text + ) + assert store._collection.metadata["hnsw:space"] == "cosine" + assert new_store._collection.metadata["hnsw:space"] == "cosine" + + @pytest.mark.integration + def test_metadata_initialization(self, caplog): + store = ChromaDocumentStore( + "test_5", + distance_function="cosine", + metadata={ + "hnsw:space": "ip", + "hnsw:search_ef": 101, + "hnsw:construction_ef": 102, + "hnsw:M": 103, + }, + ) + assert store._collection.metadata["hnsw:space"] == "ip" + assert store._collection.metadata["hnsw:search_ef"] == 101 + assert store._collection.metadata["hnsw:construction_ef"] == 102 + assert store._collection.metadata["hnsw:M"] == 103 + + with caplog.at_level(logging.WARNING): + new_store = ChromaDocumentStore( + "test_5", + metadata={ + "hnsw:space": "l2", + "hnsw:search_ef": 101, + "hnsw:construction_ef": 102, + "hnsw:M": 103, + }, + ) + + assert ( + "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + in caplog.text + ) + assert store._collection.metadata["hnsw:space"] == "ip" + assert new_store._collection.metadata["hnsw:space"] == "ip" @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_filter_document_dataframe(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): @@ -143,10 +240,6 @@ def test_eq_filter_table(self, document_store: ChromaDocumentStore, filterable_d def test_eq_filter_embedding(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip(reason="$in operator is not supported.") - def test_in_filter_explicit(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass - @pytest.mark.skip(reason="$in operator is not supported. 
Filter on table contents is not supported.") def test_in_filter_table(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @@ -181,12 +274,6 @@ def test_filter_simple_implicit_and_with_multi_key_dict( ): pass - @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_simple_explicit_and_with_multikey_dict( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): - pass - @pytest.mark.skip(reason="Filter syntax not supported.") def test_filter_simple_explicit_and_with_list( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] @@ -197,10 +284,6 @@ def test_filter_simple_explicit_and_with_list( def test_filter_simple_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass - @pytest.mark.skip(reason="Filter syntax not supported.") - def test_filter_nested_explicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - pass - @pytest.mark.skip(reason="Filter syntax not supported.") def test_filter_nested_implicit_and(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): pass @@ -230,15 +313,3 @@ def test_filter_nested_multiple_identical_operators_same_level( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] ): pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_fail(self, document_store: ChromaDocumentStore): - pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_skip(self, document_store: ChromaDocumentStore): - pass - - @pytest.mark.skip(reason="Duplicate policy not supported.") - def test_write_duplicate_overwrite(self, document_store: ChromaDocumentStore): - pass diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index b430e5fda..f0e71828d 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -2,10 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 import pytest +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever from haystack_integrations.document_stores.chroma import ChromaDocumentStore +@pytest.mark.integration +def test_retriever_init(request): + ds = ChromaDocumentStore( + collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" + ) + retriever = ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace") + assert retriever.filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + ChromaQueryTextRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") + + @pytest.mark.integration def test_retriever_to_json(request): ds = ChromaDocumentStore( @@ -17,6 +30,7 @@ def test_retriever_to_json(request): "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, + "filter_policy": "replace", "document_store": { "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", "init_parameters": { @@ -24,6 +38,7 @@ def test_retriever_to_json(request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, "api_key": "1234567890", + "distance_function": "l2", }, }, }, @@ -37,6 +52,7 @@ def test_retriever_from_json(request): "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, + "filter_policy": "replace", "document_store": { "type": 
"haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", "init_parameters": { @@ -44,6 +60,36 @@ def test_retriever_from_json(request): "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": ".", "api_key": "1234567890", + "distance_function": "l2", + }, + }, + }, + } + retriever = ChromaQueryTextRetriever.from_dict(data) + assert retriever.document_store._collection_name == request.node.name + assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" + assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"} + assert retriever.document_store._persist_path == "." + assert retriever.filters == {"bar": "baz"} + assert retriever.top_k == 42 + assert retriever.filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.integration +def test_retriever_from_json_no_filter_policy(request): + data = { + "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever", + "init_parameters": { + "filters": {"bar": "baz"}, + "top_k": 42, + "document_store": { + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_retriever_from_json_no_filter_policy", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": ".", + "api_key": "1234567890", + "distance_function": "l2", }, }, }, @@ -55,3 +101,4 @@ def test_retriever_from_json(request): assert retriever.document_store._persist_path == "." assert retriever.filters == {"bar": "baz"} assert retriever.top_k == 42 + assert retriever.filter_policy == FilterPolicy.REPLACE # default even if not specified diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md new file mode 100644 index 000000000..3067b0a5e --- /dev/null +++ b/integrations/cohere/CHANGELOG.md @@ -0,0 +1,108 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Update Anthropic/Cohere for tools use (#790) +- Update Cohere default LLMs, add examples and update unit tests (#838) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) + +## [integrations/cohere-v1.1.1] - 2024-06-12 + +## [integrations/cohere-v1.1.0] - 2024-05-24 + +### 🐛 Bug Fixes + +- Remove support for generate API (#755) + +## [integrations/cohere-v1.0.0] - 2024-05-03 + +## [integrations/cohere-v0.7.0] - 2024-05-02 + +## [integrations/cohere-v0.6.0] - 2024-04-08 + +### 🚀 Features + +- Add Cohere ranker (#643) + +## [integrations/cohere-v0.5.0] - 2024-03-29 + +## [integrations/cohere-v0.4.1] - 2024-03-21 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Fix tests (#561) + +* fix unit tests + +* try + +* remove flaky check + +### 📚 Documentation + +- Update category slug (#442) +- Review cohere integration (#500) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Update Cohere integration to use new generic callable (de)serializers for their callback handlers (#453) +- Use `serialize_callable` instead of `serialize_callback_handler` in Cohere (#460) + +### Cohere + +- Fix linting (#509) + +## [integrations/cohere-v0.4.0] - 2024-02-12 + +### 🐛 Bug Fixes + +- Fix Cohere tests (#337) +- Cohere inconsistent embeddings and documents lengths (#284) + +### 🚜 Refactor + +- [**breaking**] Use `Secret` for API keys in Cohere components (#386) + +### 🧪 Testing + +- Fix failing `TestCohereChatGenerator.test_from_dict_fail_wo_env_var` test (#393) + +## 
[integrations/cohere-v0.3.0] - 2024-01-25 + +### 🐛 Bug Fixes + +- Fix project urls (#96) + + +- Cohere namespace reorg (#271) + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +### ⚙️ Miscellaneous Tasks + +- [**breaking**] Rename `model_name` to `model` in the Cohere integration (#222) +- Cohere namespace change (#247) + +## [integrations/cohere-v0.2.0] - 2023-12-11 + +### 🚀 Features + +- Add support for V3 Embed models to CohereEmbedders (#89) + +## [integrations/cohere-v0.1.1] - 2023-12-07 + +## [integrations/cohere-v0.0.1] - 2023-12-04 + + diff --git a/integrations/cohere/examples/cohere_embedding.py b/integrations/cohere/examples/cohere_embedding.py new file mode 100644 index 000000000..e6fe3cc35 --- /dev/null +++ b/integrations/cohere/examples/cohere_embedding.py @@ -0,0 +1,28 @@ +from haystack import Document, Pipeline +from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.embedders.cohere.document_embedder import CohereDocumentEmbedder +from haystack_integrations.components.embedders.cohere.text_embedder import CohereTextEmbedder + +document_store = InMemoryDocumentStore(embedding_similarity_function="cosine") + +documents = [ + Document(content="My name is Wolfgang and I live in Berlin"), + Document(content="I saw a black horse running"), + Document(content="Germany has many big cities"), +] + +document_embedder = CohereDocumentEmbedder() +documents_with_embeddings = document_embedder.run(documents)["documents"] +document_store.write_documents(documents_with_embeddings) + +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", CohereTextEmbedder()) +query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "Who lives in Berlin?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result["retriever"]["documents"][0]) # noqa: T201 diff --git a/integrations/cohere/examples/cohere_generation.py b/integrations/cohere/examples/cohere_generation.py new file mode 100644 index 000000000..cd79e37d3 --- /dev/null +++ b/integrations/cohere/examples/cohere_generation.py @@ -0,0 +1,53 @@ +# This example demonstrates a corrective Haystack pipeline with Cohere LLM integration that runs until the +# generated output satisfies a strict JSON schema. +# +# The pipeline includes the following components: +# - BranchJoiner: https://docs.haystack.deepset.ai/reference/joiners-api#branchjoiner +# - JsonSchemaValidator: https://docs.haystack.deepset.ai/reference/validators-api#jsonschemavalidator +# +# The pipeline workflow: +# 1. Receives a user message requesting to create a JSON object from "Peter Parker" aka Superman. +# 2. Processes the message through components to generate a response using Cohere command-r model. +# 3. Validates the generated response against a predefined JSON schema for person data. +# 4. If the response does not meet the schema, the JsonSchemaValidator provides details on how to correct the errors. +# 4a. The pipeline loops back, using the error information to generate a new JSON object until it satisfies the schema. +# 5. If the response is validated against the schema, outputs the validated JSON object. 
+ +from typing import List + +from haystack import Pipeline +from haystack.components.converters import OutputAdapter +from haystack.components.joiners import BranchJoiner +from haystack.components.validators import JsonSchemaValidator +from haystack.dataclasses import ChatMessage +from haystack_integrations.components.generators.cohere import CohereChatGenerator + +# Defines a JSON schema for validating a person's data. The schema specifies that a valid object must +# have first_name, last_name, and nationality properties, with specific constraints on their values. +person_schema = { + "type": "object", + "properties": { + "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]}, + }, + "required": ["first_name", "last_name", "nationality"], +} + +# Initialize a pipeline +pipe = Pipeline() + +# Add components to the pipeline +pipe.add_component("joiner", BranchJoiner(List[ChatMessage])) +pipe.add_component("fc_llm", CohereChatGenerator(model="command-r")) +pipe.add_component("validator", JsonSchemaValidator(json_schema=person_schema)) +pipe.add_component("adapter", OutputAdapter("{{chat_message}}", List[ChatMessage])), +# And connect them +pipe.connect("adapter", "joiner") +pipe.connect("joiner", "fc_llm") +pipe.connect("fc_llm.replies", "validator.messages") +pipe.connect("validator.validation_error", "joiner") + +result = pipe.run(data={"adapter": {"chat_message": [ChatMessage.from_user("Create json from Peter Parker")]}}) + +print(result["validator"]["validated"]) # noqa: T201 diff --git a/integrations/cohere/examples/cohere_ranker_in_a_pipeline.py b/integrations/cohere/examples/cohere_ranker.py similarity index 81% rename from integrations/cohere/examples/cohere_ranker_in_a_pipeline.py rename to integrations/cohere/examples/cohere_ranker.py index 2234eb3a6..79a3d346d 100644 --- a/integrations/cohere/examples/cohere_ranker_in_a_pipeline.py +++ b/integrations/cohere/examples/cohere_ranker.py @@ -15,7 +15,7 @@ document_store.write_documents(docs) retriever = InMemoryBM25Retriever(document_store=document_store) -ranker = CohereRanker(model="rerank-english-v2.0", top_k=3) +ranker = CohereRanker(model="rerank-english-v2.0") document_ranker_pipeline = Pipeline() document_ranker_pipeline.add_component(instance=retriever, name="retriever") @@ -24,6 +24,5 @@ document_ranker_pipeline.connect("retriever.documents", "ranker.documents") query = "Cities in France" -res = document_ranker_pipeline.run( - data={"retriever": {"query": query, "top_k": 3}, "ranker": {"query": query, "top_k": 3}} -) +res = document_ranker_pipeline.run(data={"retriever": {"query": query}, "ranker": {"query": query, "top_k": 2}}) +print(res["ranker"]["documents"]) # noqa: T201 diff --git a/integrations/cohere/pydoc/config.yml b/integrations/cohere/pydoc/config.yml index 5d4e747f5..53c54b664 100644 --- a/integrations/cohere/pydoc/config.yml +++ b/integrations/cohere/pydoc/config.yml @@ -19,7 +19,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Cohere integration for Haystack category_slug: integrations-api title: Cohere diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index a3ad97582..04fe15585 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -22,10 +22,7 @@ 
classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "cohere<5", -] +dependencies = ["haystack-ai", "cohere==5.*"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme" @@ -44,12 +41,14 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -60,7 +59,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -138,12 +137,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] @@ -164,6 +159,6 @@ markers = [ "embedders: embedders tests", "generators: generators tests", "chat_generators: chat_generators tests", - "ranker: ranker tests" + "ranker: ranker tests", ] log_cli = true diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 944a4746f..59a04cf3c 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -8,7 +8,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response -from cohere import COHERE_API_URL, AsyncClient, Client +from cohere import AsyncClient, Client @component @@ -39,10 +39,9 @@ def __init__( api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "embed-english-v2.0", input_type: str = "search_document", - api_base_url: str = COHERE_API_URL, + api_base_url: str = "https://api.cohere.com", truncate: str = "END", use_async_client: bool = False, - max_retries: int = 3, timeout: int = 120, batch_size: int = 32, progress_bar: bool = True, @@ -67,7 +66,6 @@ def __init__( If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned. :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. - :param max_retries: maximal number of retries for requests. :param timeout: request timeout in seconds. 
:param batch_size: number of Documents to encode at once. :param progress_bar: whether to show a progress bar or not. Can be helpful to disable in production deployments @@ -82,7 +80,6 @@ def __init__( self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client - self.max_retries = max_retries self.timeout = timeout self.batch_size = batch_size self.progress_bar = progress_bar @@ -104,7 +101,6 @@ def to_dict(self) -> Dict[str, Any]: api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, - max_retries=self.max_retries, timeout=self.timeout, batch_size=self.batch_size, progress_bar=self.progress_bar, @@ -169,8 +165,7 @@ def run(self, documents: List[Document]): if self.use_async_client: cohere_client = AsyncClient( api_key, - api_url=self.api_base_url, - max_retries=self.max_retries, + base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) @@ -180,8 +175,7 @@ def run(self, documents: List[Document]): else: cohere_client = Client( api_key, - api_url=self.api_base_url, - max_retries=self.max_retries, + base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index 95355cbe8..80ede51bf 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -8,7 +8,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response -from cohere import COHERE_API_URL, AsyncClient, Client +from cohere import AsyncClient, Client @component @@ -36,10 +36,9 @@ def __init__( api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "embed-english-v2.0", input_type: str = "search_query", - api_base_url: str = COHERE_API_URL, + api_base_url: str = "https://api.cohere.com", truncate: str = "END", use_async_client: bool = False, - max_retries: int = 3, timeout: int = 120, ): """ @@ -60,7 +59,6 @@ def __init__( If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned. :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. - :param max_retries: maximum number of retries for requests. :param timeout: request timeout in seconds. 
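For illustration only, a minimal sketch of calling the embedder with these parameters (assuming the `COHERE_API_KEY` or `CO_API_KEY` environment variable is set; the query text is arbitrary):

```python
from haystack_integrations.components.embedders.cohere.text_embedder import CohereTextEmbedder

# Defaults shown explicitly; the API key is read from the environment.
embedder = CohereTextEmbedder(
    model="embed-english-v2.0",
    input_type="search_query",
    truncate="END",
    timeout=120,
)
result = embedder.run(text="Who lives in Berlin?")
# result holds the query embedding and the Cohere response metadata
print(result)
```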
""" @@ -70,7 +68,6 @@ def __init__( self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client - self.max_retries = max_retries self.timeout = timeout def to_dict(self) -> Dict[str, Any]: @@ -88,7 +85,6 @@ def to_dict(self) -> Dict[str, Any]: api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, - max_retries=self.max_retries, timeout=self.timeout, ) @@ -131,8 +127,7 @@ def run(self, text: str): if self.use_async_client: cohere_client = AsyncClient( api_key, - api_url=self.api_base_url, - max_retries=self.max_retries, + base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) @@ -142,8 +137,7 @@ def run(self, text: str): else: cohere_client = Client( api_key, - api_url=self.api_base_url, - max_retries=self.max_retries, + base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index 21a65e3da..a5c20cb35 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -5,7 +5,7 @@ from tqdm import tqdm -from cohere import AsyncClient, Client, CohereError +from cohere import AsyncClient, Client async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): @@ -24,20 +24,14 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], """ all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} - try: - response = await cohere_async_client.embed( - texts=texts, model=model_name, input_type=input_type, truncate=truncate - ) - if response.meta is not None: - metadata = response.meta - for emb in response.embeddings: - all_embeddings.append(emb) - return all_embeddings, metadata + response = await cohere_async_client.embed(texts=texts, model=model_name, input_type=input_type, truncate=truncate) + if response.meta is not None: + metadata = response.meta + for emb in response.embeddings: + all_embeddings.append(emb) - except CohereError as error_response: - msg = error_response.message - raise ValueError(msg) from error_response + return all_embeddings, metadata def get_response( @@ -62,21 +56,16 @@ def get_response( all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} - try: - for i in tqdm( - range(0, len(texts), batch_size), - disable=not progress_bar, - desc="Calculating embeddings", - ): - batch = texts[i : i + batch_size] - response = cohere_client.embed(batch, model=model_name, input_type=input_type, truncate=truncate) - for emb in response.embeddings: - all_embeddings.append(emb) - if response.meta is not None: - metadata = response.meta - - return all_embeddings, metadata - - except CohereError as error_response: - msg = error_response.message - raise ValueError(msg) from error_response + for i in tqdm( + range(0, len(texts), batch_size), + disable=not progress_bar, + desc="Calculating embeddings", + ): + batch = texts[i : i + batch_size] + response = cohere_client.embed(texts=batch, model=model_name, input_type=input_type, truncate=truncate) + for emb in response.embeddings: + all_embeddings.append(emb) + if response.meta is not None: + metadata = response.meta + + return all_embeddings, metadata diff --git 
a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 2df564fce..568a26979 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -15,18 +15,16 @@ @component class CohereChatGenerator: """ - Enables text generation using Cohere's chat endpoint. + Completes chats using Cohere's models through Cohere `chat` endpoint. - This component is designed to inference Cohere's chat models. + You can customize how the text is generated by passing parameters to the + Cohere API through the `**generation_kwargs` parameter. You can do this when + initializing or running the component. Any parameter that works with + `cohere.Client.chat` will work here too. + For details, see [Cohere API](https://docs.cohere.com/reference/chat). - Users can pass any text generation parameters valid for the `cohere.Client,chat` method - directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` - parameter in `run` method. + ### Usage example - Invocations are made using 'cohere' package. - See [Cohere API](https://docs.cohere.com/reference/chat) for more details. - - Example usage: ```python from haystack_integrations.components.generators.cohere import CohereChatGenerator @@ -40,7 +38,7 @@ class CohereChatGenerator: def __init__( self, api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), - model: str = "command", + model: str = "command-r", streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, @@ -49,40 +47,41 @@ def __init__( """ Initialize the CohereChatGenerator instance. - :param api_key: the API key for the Cohere API. - :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, - command-nightly-light]. - :param streaming_callback: a callback function to be called with the streaming response. - :param api_base_url: the base URL of the Cohere API. - :param generation_kwargs: additional model parameters. These will be used during generation. Refer to - https://docs.cohere.com/reference/chat for more details. + :param api_key: The API key for the Cohere API. + :param model: The name of the model to use. You can use models from the `command` family. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) + as an argument. + :param api_base_url: The base URL of the Cohere API. + :param generation_kwargs: Other parameters to use for the model during generation. For a list of parameters, + see [Cohere Chat endpoint](https://docs.cohere.com/reference/chat). Some of the parameters are: - 'chat_history': A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message. - - 'preamble_override': When specified, the default Cohere preamble will be replaced with the provided one. - - 'conversation_id': An alternative to chat_history. Previous conversations can be resumed by providing - the conversation's identifier. 
The contents of message and the model's response will be stored - as part of this conversation.If a conversation with this id does not already exist, - a new conversation will be created. - - 'prompt_truncation': Defaults to AUTO when connectors are specified and OFF in all other cases. - Dictates how the prompt will be constructed. - - 'connectors': Accepts {"id": "web-search"}, and/or the "id" for a custom connector, if you've created one. - When specified, the model's reply will be enriched with information found by + - 'preamble': When specified, replaces the default Cohere preamble with the provided one. + - 'conversation_id': An alternative to `chat_history`. Previous conversations can be resumed by providing + the conversation's identifier. The contents of message and the model's response are stored + as part of this conversation. If a conversation with this ID doesn't exist, + a new conversation is created. + - 'prompt_truncation': Defaults to `AUTO` when connectors are specified and to `OFF` in all other cases. + Dictates how the prompt is constructed. + - 'connectors': Accepts {"id": "web-search"}, and the "id" for a custom connector, if you created one. + When specified, the model's reply is enriched with information found by querying each of the connectors (RAG). - 'documents': A list of relevant documents that the model can use to enrich its reply. - - 'search_queries_only': Defaults to false. When true, the response will only contain a - list of generated search queries, but no search will take place, and no reply from the model to the - user's message will be generated. - - 'citation_quality': Defaults to "accurate". Dictates the approach taken to generating citations + - 'search_queries_only': Defaults to `False`. When `True`, the response only contains a + list of generated search queries, but no search takes place, and no reply from the model to the + user's message is generated. + - 'citation_quality': Defaults to `accurate`. Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want - "accurate" results or "fast" results. + `accurate` results or `fast` results. - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations.
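As an illustrative sketch (assuming the `COHERE_API_KEY` or `CO_API_KEY` environment variable is set; the specific parameter values are arbitrary), a few of these options can be passed like this:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.cohere import CohereChatGenerator

generator = CohereChatGenerator(
    model="command-r",
    generation_kwargs={"temperature": 0.3, "citation_quality": "fast"},
)
# kwargs passed at run time are merged with the ones set at initialization
result = generator.run(
    messages=[ChatMessage.from_user("What is retrieval-augmented generation?")],
    generation_kwargs={"max_tokens": 100},
)
print(result["replies"][0].content)
```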
""" cohere_import.check() if not api_base_url: - api_base_url = cohere.COHERE_API_URL + api_base_url = "https://api.cohere.com" if generation_kwargs is None: generation_kwargs = {} self.api_key = api_key @@ -92,7 +91,7 @@ def __init__( self.generation_kwargs = generation_kwargs self.model_parameters = kwargs self.client = cohere.Client( - api_key=self.api_key.resolve_value(), api_url=self.api_base_url, client_name="haystack" + api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -156,30 +155,46 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} chat_history = [self._message_to_dict(m) for m in messages[:-1]] - response = self.client.chat( - message=messages[-1].content, - model=self.model, - stream=self.streaming_callback is not None, - chat_history=chat_history, - **generation_kwargs, - ) if self.streaming_callback: - for chunk in response: - if chunk.event_type == "text-generation": - stream_chunk = self._build_chunk(chunk) - self.streaming_callback(stream_chunk) - chat_message = ChatMessage.from_assistant(content=response.texts) - chat_message.meta.update( - { - "model": self.model, - "usage": response.token_count, - "index": 0, - "finish_reason": response.finish_reason, - "documents": response.documents, - "citations": response.citations, - } + response = self.client.chat_stream( + message=messages[-1].content, + model=self.model, + chat_history=chat_history, + **generation_kwargs, ) + + response_text = "" + finish_response = None + for event in response: + if event.event_type == "text-generation": + stream_chunk = self._build_chunk(event) + self.streaming_callback(stream_chunk) + response_text += event.text + elif event.event_type == "stream-end": + finish_response = event.response + chat_message = ChatMessage.from_assistant(content=response_text) + + if finish_response and finish_response.meta: + if finish_response.meta.billed_units: + tokens_in = finish_response.meta.billed_units.input_tokens or -1 + tokens_out = finish_response.meta.billed_units.output_tokens or -1 + chat_message.meta["usage"] = tokens_in + tokens_out + chat_message.meta.update( + { + "model": self.model, + "index": 0, + "finish_reason": finish_response.finish_reason, + "documents": finish_response.documents, + "citations": finish_response.citations, + } + ) else: + response = self.client.chat( + message=messages[-1].content, + model=self.model, + chat_history=chat_history, + **generation_kwargs, + ) chat_message = self._build_message(response) return {"replies": [chat_message]} @@ -190,7 +205,7 @@ def _build_chunk(self, chunk) -> StreamingChunk: :param choice: The choice returned by the OpenAI API. :returns: The StreamingChunk. """ - chat_message = StreamingChunk(content=chunk.text, meta={"index": chunk.index, "event_type": chunk.event_type}) + chat_message = StreamingChunk(content=chunk.text, meta={"event_type": chunk.event_type}) return chat_message def _build_message(self, cohere_response): @@ -199,14 +214,19 @@ def _build_message(self, cohere_response): :param cohere_response: The completion returned by the Cohere API. :returns: The ChatMessage. 
""" - content = cohere_response.text - message = ChatMessage.from_assistant(content=content) + message = None + if cohere_response.tool_calls: + # TODO revisit to see if we need to handle multiple tool calls + message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json()) + elif cohere_response.text: + message = ChatMessage.from_assistant(content=cohere_response.text) + total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens message.meta.update( { "model": self.model, - "usage": cohere_response.token_count, + "usage": total_tokens, "index": 0, - "finish_reason": None, + "finish_reason": cohere_response.finish_reason, "documents": cohere_response.documents, "citations": cohere_response.citations, } diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index dd98b3c4b..0eb65b368 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -2,27 +2,25 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Dict, List, Optional -from haystack import component, default_from_dict, default_to_dict -from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler -from haystack.dataclasses import StreamingChunk -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack import component +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.utils import Secret -from cohere import COHERE_API_URL, Client -from cohere.responses import Generations +from .chat.chat_generator import CohereChatGenerator logger = logging.getLogger(__name__) @component -class CohereGenerator: - """LLM Generator compatible with Cohere's generate endpoint. +class CohereGenerator(CohereChatGenerator): + """Generates text using Cohere's models through Cohere's `generate` endpoint. - Queries the LLM using Cohere's API. Invocations are made using 'cohere' package. - See [Cohere API](https://docs.cohere.com/reference/generate) for more details. + NOTE: Cohere discontinued the `generate` API, so this generator is a mere wrapper + around `CohereChatGenerator` provided for backward compatibility. - Example usage: + ### Usage example ```python from haystack_integrations.components.generators.cohere import CohereGenerator @@ -35,7 +33,7 @@ class CohereGenerator: def __init__( self, api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), - model: str = "command", + model: str = "command-r", streaming_callback: Optional[Callable] = None, api_base_url: Optional[str] = None, **kwargs, @@ -43,82 +41,18 @@ def __init__( """ Instantiates a `CohereGenerator` component. - :param api_key: the API key for the Cohere API. - :param model: the name of the model to use. Available models are: [command, command-light, command-nightly, - command-nightly-light]. - :param streaming_callback: A callback function to be called with the streaming response. - :param api_base_url: the base URL of the Cohere API. - :param kwargs: additional model parameters. These will be used during generation. Refer to - https://docs.cohere.com/reference/generate for more details. 
- Some of the parameters are: - - 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024. - - 'truncate': One of NONE|START|END to specify how the API will handle inputs longer than the maximum token - length. Defaults to END. - - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures - mean less random generations. - - 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt, - temperature etc. You can create presets in the playground. - - 'end_sequences': The generated text will be cut at the beginning of the earliest occurrence of an end - sequence. The sequence will be excluded from the text. - - 'stop_sequences': The generated text will be cut at the end of the earliest occurrence of a stop sequence. - The sequence will be included the text. - - 'k': Defaults to 0, min value of 0.01, max value of 0.99. - - 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for - generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. - - 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger - a penalty is applied to previously present tokens, proportional to how many times they have already - appeared in the prompt or prior generation.' - - 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce - repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied - equally to all tokens that have already appeared, regardless of their exact frequencies. - - 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned - with the response. Defaults to NONE. - - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include - desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. + :param api_key: Cohere API key. + :param model: Cohere model to use for generation. + :param streaming_callback: Callback function that is called when a new token is received from the stream. + The callback function accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) + as an argument. + :param api_base_url: Cohere base URL. + :param **kwargs: Additional arguments passed to the model. These arguments are specific to the model. + You can check them in model's documentation. """ - if not api_base_url: - api_base_url = COHERE_API_URL - self.api_key = api_key - self.model = model - self.streaming_callback = streaming_callback - self.api_base_url = api_base_url - self.model_parameters = kwargs - self.client = Client(api_key=self.api_key.resolve_value(), api_url=self.api_base_url, client_name="haystack") - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes the component to a dictionary. - - :returns: - Dictionary with serialized data. - """ - return default_to_dict( - self, - model=self.model, - streaming_callback=serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None, - api_base_url=self.api_base_url, - api_key=self.api_key.to_dict(), - **self.model_parameters, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": - """ - Deserializes the component from a dictionary. - - :param data: - Dictionary to deserialize from. - :returns: - Deserialized component. 
- """ - init_params = data.get("init_parameters", {}) - deserialize_secrets_inplace(init_params, ["api_key"]) - if "streaming_callback" in init_params and init_params["streaming_callback"] is not None: - data["init_parameters"]["streaming_callback"] = deserialize_callback_handler( - init_params["streaming_callback"] - ) - return default_from_dict(cls, data) + # Note we have to call super() like this because of the way components are dynamically built with the decorator + super(CohereGenerator, self).__init__(api_key, model, streaming_callback, api_base_url, None, **kwargs) # noqa @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) def run(self, prompt: str): @@ -127,45 +61,10 @@ def run(self, prompt: str): :param prompt: the prompt to be sent to the generative model. :returns: A dictionary with the following keys: - - `replies`: the list of replies generated by the model. - - `meta`: metadata about the request. - """ - response = self.client.generate( - model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters - ) - if self.streaming_callback: - metadata_dict: Dict[str, Any] = {} - for chunk in response: - stream_chunk = self._build_chunk(chunk) - self.streaming_callback(stream_chunk) - replies = response.texts - metadata_dict["finish_reason"] = response.finish_reason - metadata = [metadata_dict] - self._check_truncated_answers(metadata) - return {"replies": replies, "meta": metadata} - - metadata = [{"finish_reason": resp.finish_reason} for resp in cast(Generations, response)] - replies = [resp.text for resp in response] - self._check_truncated_answers(metadata) - return {"replies": replies, "meta": metadata} - - def _build_chunk(self, chunk) -> StreamingChunk: - """ - Converts the response from the Cohere API to a StreamingChunk. - :param chunk: The chunk returned by the OpenAI API. - :returns: The StreamingChunk. - """ - streaming_chunk = StreamingChunk(content=chunk.text, meta={"index": chunk.index}) - return streaming_chunk - - def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): - """ - Check the `finish_reason` returned with the Cohere response. - If the `finish_reason` is `MAX_TOKEN`, log a warning to the user. - :param metadata: The metadata returned by the Cohere API. + - `replies`: A list of replies generated by the model. + - `meta`: Information about the request. """ - if metadata[0]["finish_reason"] == "MAX_TOKENS": - logger.warning( - "Responses have been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." 
- ) + chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={}) + # Note we have to call super() like this because of the way components are dynamically built with the decorator + results = super(CohereGenerator, self).run([chat_message]) # noqa + return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py index f902a286c..7da823bbc 100644 --- a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -36,7 +36,7 @@ def __init__( model: str = "rerank-english-v2.0", top_k: int = 10, api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), - api_base_url: str = cohere.COHERE_API_URL, + api_base_url: str = "https://api.cohere.com", max_chunks_per_doc: Optional[int] = None, meta_fields_to_embed: Optional[List[str]] = None, meta_data_separator: str = "\n", @@ -66,7 +66,7 @@ def __init__( self.meta_fields_to_embed = meta_fields_to_embed or [] self.meta_data_separator = meta_data_separator self._cohere_client = cohere.Client( - api_key=self.api_key.resolve_value(), api_url=self.api_base_url, client_name="haystack" + api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) def to_dict(self) -> Dict[str, Any]: diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 1684046eb..6521503f2 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -1,8 +1,9 @@ +import json import os -from unittest.mock import Mock, patch +from unittest.mock import Mock -import cohere import pytest +from cohere.core import ApiError from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.utils import Secret @@ -11,30 +12,6 @@ pytestmark = pytest.mark.chat_generators -@pytest.fixture -def mock_chat_response(): - """ - Mock the CohereI API response and reuse it for tests - """ - with patch("cohere.Client.chat", autospec=True) as mock_chat_response: - # mimic the response from the Cohere API - - mock_response = Mock() - mock_response.text = "I'm fine, thanks." 
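Because `CohereGenerator` now delegates to `CohereChatGenerator`, the prompt-based interface keeps returning `replies` and `meta` lists; a minimal sketch, assuming `COHERE_API_KEY`/`CO_API_KEY` is exported:

```python
from haystack.components.generators.utils import print_streaming_chunk
from haystack_integrations.components.generators.cohere import CohereGenerator

# The prompt is wrapped into a single user ChatMessage internally and sent to the chat endpoint.
generator = CohereGenerator(model="command-r", streaming_callback=print_streaming_chunk)
result = generator.run(prompt="What's the capital of France?")
print(result["replies"][0])  # plain string reply
print(result["meta"][0])     # metadata dict (model, usage, finish_reason, ...)
```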
- mock_response.token_count = { - "prompt_tokens": 66, - "response_tokens": 78, - "total_tokens": 144, - "billed_tokens": 133, - } - mock_response.meta = { - "api_version": {"version": "1"}, - "billed_units": {"input_tokens": 55, "output_tokens": 78}, - } - mock_chat_response.return_value = mock_response - yield mock_chat_response - - def streaming_chunk(text: str): """ Mock chunks of streaming responses from the Cohere API @@ -58,9 +35,9 @@ def test_init_default(self, monkeypatch): component = CohereChatGenerator() assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) - assert component.model == "command" + assert component.model == "command-r" assert component.streaming_callback is None - assert component.api_base_url == cohere.COHERE_API_URL + assert component.api_base_url == "https://api.cohere.com" assert not component.generation_kwargs def test_init_fail_wo_api_key(self, monkeypatch): @@ -90,10 +67,10 @@ def test_to_dict_default(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "streaming_callback": None, "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, - "api_base_url": "https://api.cohere.ai", + "api_base_url": "https://api.cohere.com", "generation_kwargs": {}, }, } @@ -123,7 +100,7 @@ def test_to_dict_with_parameters(self, monkeypatch): def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereChatGenerator( - model="command", + model="command-r", streaming_callback=lambda x: x, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, @@ -132,7 +109,7 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "api_base_url": "test-base-url", "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "tests.test_cohere_chat_generator.", @@ -146,7 +123,7 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "api_base_url": "test-base-url", "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -154,7 +131,7 @@ def test_from_dict(self, monkeypatch): }, } component = CohereChatGenerator.from_dict(data) - assert component.model == "command" + assert component.model == "command-r" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @@ -165,7 +142,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "api_base_url": "test-base-url", "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": 
"haystack.components.generators.utils.print_streaming_chunk", @@ -175,73 +152,11 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError): CohereChatGenerator.from_dict(data) - def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 - component = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) - response = component.run(chat_messages) - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_message_to_dict(self, chat_messages): obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) dictionary = [obj._message_to_dict(message) for message in chat_messages] assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] - def test_run_with_params(self, chat_messages, mock_chat_response): - component = CohereChatGenerator( - api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} - ) - response = component.run(chat_messages) - - # check that the component calls the Cohere API with the correct parameters - _, kwargs = mock_chat_response.call_args - assert kwargs["max_tokens"] == 10 - assert kwargs["temperature"] == 0.5 - - # check that the component returns the correct response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - def test_run_streaming(self, chat_messages, mock_chat_response): - streaming_call_count = 0 - - # Define the streaming callback function and assert that it is called with StreamingChunk objects - def streaming_callback_fn(chunk: StreamingChunk): - nonlocal streaming_call_count - streaming_call_count += 1 - assert isinstance(chunk, StreamingChunk) - - generator = CohereChatGenerator( - api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback_fn - ) - - # Create a fake streamed response - # self needed here, don't remove - def mock_iter(self): # noqa: ARG001 - yield streaming_chunk("Hello") - yield streaming_chunk("How are you?") - - mock_response = Mock(**{"__iter__": mock_iter}) - mock_chat_response.return_value = mock_response - - response = generator.run(chat_messages) - - # Assert that the streaming callback was called twice - assert streaming_call_count == 2 - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", @@ -262,7 +177,7 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong") - with pytest.raises(cohere.CohereAPIError): + with pytest.raises(ApiError): component.run(chat_messages) @pytest.mark.skipif( @@ -288,7 +203,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = 
results["replies"][0] - assert "Paris" in message.content[0] + assert "Paris" in message.content assert message.meta["finish_reason"] == "COMPLETE" @@ -332,7 +247,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content[0] + assert "Paris" in message.content assert message.meta["finish_reason"] == "COMPLETE" @@ -340,3 +255,40 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.meta["documents"] is not None assert message.meta["citations"] is not None + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_tools_use(self): + # See https://docs.anthropic.com/en/docs/tool-use for more information + tools_schema = { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given ticker symbol.", + "parameter_definitions": { + "ticker": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.", + "required": True, + } + }, + } + client = CohereChatGenerator(model="command-r") + response = client.run( + messages=[ChatMessage.from_user("What is the current price of AAPL?")], + generation_kwargs={"tools": [tools_schema]}, + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert first_reply.meta, "First reply has no metadata" + fc_response = json.loads(first_reply.content) + assert "name" in fc_response, "First reply does not contain name of the tool" + assert "parameters" in fc_response, "First reply does not contain parameters of the tool" diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generator.py similarity index 86% rename from integrations/cohere/tests/test_cohere_generators.py rename to integrations/cohere/tests/test_cohere_generator.py index 32fdfce50..736b6bfbf 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generator.py @@ -4,12 +4,13 @@ import os import pytest -from cohere import COHERE_API_URL +from cohere.core import ApiError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereGenerator pytestmark = pytest.mark.generators +COHERE_API_URL = "https://api.cohere.com" class TestCohereGenerator: @@ -17,7 +18,7 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "foo") component = CohereGenerator() assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) - assert component.model == "command" + assert component.model == "command-r" assert component.streaming_callback is None assert component.api_base_url == COHERE_API_URL assert component.model_parameters == {} @@ -45,10 +46,11 @@ def test_to_dict_default(self, monkeypatch): 
assert data == { "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": None, "api_base_url": COHERE_API_URL, + "generation_kwargs": {}, }, } @@ -68,18 +70,17 @@ def test_to_dict_with_parameters(self, monkeypatch): "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command-light", - "max_tokens": 10, - "some_test_param": "test-params", "api_base_url": "test-base-url", "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {}, }, } def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereGenerator( - model="command", + model="command-r", max_tokens=10, some_test_param="test-params", streaming_callback=lambda x: x, @@ -89,12 +90,11 @@ def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { - "model": "command", - "streaming_callback": "tests.test_cohere_generators.", - "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, + "model": "command-r", + "streaming_callback": "tests.test_cohere_generator.", "api_base_url": "test-base-url", - "max_tokens": 10, - "some_test_param": "test-params", + "api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True}, + "generation_kwargs": {}, }, } @@ -104,7 +104,7 @@ def test_from_dict(self, monkeypatch): data = { "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { - "model": "command", + "model": "command-r", "max_tokens": 10, "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "some_test_param": "test-params", @@ -114,20 +114,11 @@ def test_from_dict(self, monkeypatch): } component: CohereGenerator = CohereGenerator.from_dict(data) assert component.api_key == Secret.from_env_var("ENV_VAR", strict=False) - assert component.model == "command" + assert component.model == "command-r" assert component.streaming_callback == print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - def test_check_truncated_answers(self, caplog): - component = CohereGenerator(api_key=Secret.from_token("test-api-key")) - meta = [{"finish_reason": "MAX_TOKENS"}] - component._check_truncated_answers(meta) - assert caplog.records[0].message == ( - "Responses have been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." 
- ) - @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", @@ -147,10 +138,8 @@ def test_cohere_generator_run(self): ) @pytest.mark.integration def test_cohere_generator_run_wrong_model(self): - import cohere - component = CohereGenerator(model="something-obviously-wrong") - with pytest.raises(cohere.CohereAPIError): + with pytest.raises(ApiError): component.run(prompt="What's the capital of France?") @pytest.mark.skipif( diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py index 08e01c647..670e662d4 100644 --- a/integrations/cohere/tests/test_cohere_ranker.py +++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -2,12 +2,12 @@ from unittest.mock import Mock, patch import pytest -from cohere import COHERE_API_URL from haystack import Document from haystack.utils.auth import Secret from haystack_integrations.components.rankers.cohere import CohereRanker pytestmark = pytest.mark.ranker +COHERE_API_URL = "https://api.cohere.com" @pytest.fixture diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index 0b1c80fea..ffbf280e9 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -4,12 +4,12 @@ import os import pytest -from cohere import COHERE_API_URL from haystack import Document from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder pytestmark = pytest.mark.embedders +COHERE_API_URL = "https://api.cohere.com" class TestCohereDocumentEmbedder: @@ -21,7 +21,6 @@ def test_init_default(self): assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False - assert embedder.max_retries == 3 assert embedder.timeout == 120 assert embedder.batch_size == 32 assert embedder.progress_bar is True @@ -36,7 +35,6 @@ def test_init_with_parameters(self): api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, - max_retries=5, timeout=60, batch_size=64, progress_bar=False, @@ -49,7 +47,6 @@ def test_init_with_parameters(self): assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True - assert embedder.max_retries == 5 assert embedder.timeout == 60 assert embedder.batch_size == 64 assert embedder.progress_bar is False @@ -68,7 +65,6 @@ def test_to_dict(self): "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, - "max_retries": 3, "timeout": 120, "batch_size": 32, "progress_bar": True, @@ -85,7 +81,6 @@ def test_to_dict_with_custom_init_parameters(self): api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, - max_retries=5, timeout=60, batch_size=64, progress_bar=False, @@ -102,7 +97,6 @@ def test_to_dict_with_custom_init_parameters(self): "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, - "max_retries": 5, "timeout": 60, "batch_size": 64, "progress_bar": False, diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index c59fe177b..b4f3e234c 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py 
@@ -4,11 +4,11 @@ import os import pytest -from cohere import COHERE_API_URL from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereTextEmbedder pytestmark = pytest.mark.embedders +COHERE_API_URL = "https://api.cohere.com" class TestCohereTextEmbedder: @@ -24,7 +24,6 @@ def test_init_default(self): assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False - assert embedder.max_retries == 3 assert embedder.timeout == 120 def test_init_with_parameters(self): @@ -38,7 +37,6 @@ def test_init_with_parameters(self): api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, - max_retries=5, timeout=60, ) assert embedder.api_key == Secret.from_token("test-api-key") @@ -47,7 +45,6 @@ def test_init_with_parameters(self): assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True - assert embedder.max_retries == 5 assert embedder.timeout == 60 def test_to_dict(self): @@ -65,7 +62,6 @@ def test_to_dict(self): "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, - "max_retries": 3, "timeout": 120, }, } @@ -81,7 +77,6 @@ def test_to_dict_with_custom_init_parameters(self): api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, - max_retries=5, timeout=60, ) component_dict = embedder_component.to_dict() @@ -94,7 +89,6 @@ def test_to_dict_with_custom_init_parameters(self): "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, - "max_retries": 5, "timeout": 60, }, } diff --git a/integrations/deepeval/pydoc/config.yml b/integrations/deepeval/pydoc/config.yml index affa23acd..b3372f42c 100644 --- a/integrations/deepeval/pydoc/config.yml +++ b/integrations/deepeval/pydoc/config.yml @@ -18,7 +18,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: DeepEval integration for Haystack category_slug: integrations-api title: DeepEval diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 1da5cd820..44d89cb11 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -41,12 +41,14 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/deepeval-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -57,7 +59,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -139,12 +141,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/elasticsearch/CHANGELOG.md b/integrations/elasticsearch/CHANGELOG.md new file mode 100644 index 000000000..a825234bc --- /dev/null +++ b/integrations/elasticsearch/CHANGELOG.md @@ -0,0 +1,90 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Defer the database connection to when it's needed (#766) +- Add filter_policy to elasticsearch integration (#825) + +### 🐛 Bug Fixes + +- `ElasticSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#898) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/elasticsearch-v0.5.0] - 2024-05-24 + +### 🐛 Bug Fixes + +- Add support for custom mapping in ElasticsearchDocumentStore (#721) + +## [integrations/elasticsearch-v0.4.0] - 2024-04-03 + +### 📚 Documentation + +- Docstring update (#525) +- Review Elastic (#541) +- Disable-class-def (#556) + +## [integrations/elasticsearch-v0.3.0] - 2024-02-23 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +### Elasticsearch + +- Add user-agent header (#457) + +### Feat + +- Add filters to run function in retrievers of elasticsearch (#440) + +### Elasticsearch + +- Generate api docs (#322) + +## [integrations/elasticsearch-v0.2.0] - 2024-01-19 + +## [integrations/elasticsearch-v0.1.3] - 2024-01-18 + +## [integrations/elasticsearch-v0.1.2] - 2023-12-20 + +### 🐛 Bug Fixes + +- Fix project urls (#96) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +## [integrations/elasticsearch-v0.1.1] - 2023-12-05 + +### 🐛 Bug Fixes + +- Fix import and increase version (#77) + + + +## [integrations/elasticsearch-v0.1.0] - 2023-12-04 + +### 🐛 Bug Fixes + +- Fix license headers + + +## [integrations/elasticsearch-v0.0.2] - 2023-11-29 + + diff --git a/integrations/elasticsearch/pydoc/config.yml b/integrations/elasticsearch/pydoc/config.yml index 04e20f992..39ffb2e5f 100644 --- a/integrations/elasticsearch/pydoc/config.yml +++ 
b/integrations/elasticsearch/pydoc/config.yml @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Elasticsearch integration for Haystack category_slug: integrations-api title: Elasticsearch diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index 1612601e2..cb9281030 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "Silvano Cerza", email = "silvanocerza@gmail.com" }, -] +authors = [{ name = "Silvano Cerza", email = "silvanocerza@gmail.com" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,10 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "elasticsearch>=8,<9", -] +dependencies = ["haystack-ai", "elasticsearch>=8,<9"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/elasticsearch#readme" @@ -49,49 +44,30 @@ git_describe_command = 'git describe --tags --match="integrations/elasticsearch- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "pytest-xdist", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -137,9 +113,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -164,25 +146,14 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" -markers = [ - "unit: unit tests", - "integration: integration tests" -] +markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*" -] +module = ["haystack.*", "haystack_integrations.*", "pytest.*"] ignore_missing_imports = true diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py index 867d49c0e..f273c955b 100644 --- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py +++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/bm25_retriever.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore @@ -48,6 +50,7 @@ def __init__( fuzziness: str = "AUTO", top_k: int = 10, scale_score: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Initialize ElasticsearchBM25Retriever with an instance ElasticsearchDocumentStore. @@ -60,6 +63,7 @@ def __init__( for more details. :param top_k: Maximum number of Documents to return. :param scale_score: If `True` scales the Document`s scores between 0 and 1. + :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `ElasticsearchDocumentStore`. """ @@ -72,6 +76,7 @@ def __init__( self._fuzziness = fuzziness self._top_k = top_k self._scale_score = scale_score + self._filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy def to_dict(self) -> Dict[str, Any]: """ @@ -86,6 +91,7 @@ def to_dict(self) -> Dict[str, Any]: fuzziness=self._fuzziness, top_k=self._top_k, scale_score=self._scale_score, + filter_policy=self._filter_policy.value, document_store=self._document_store.to_dict(), ) @@ -102,6 +108,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchBM25Retriever": data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. 
+ if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -110,14 +120,17 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio Retrieve documents using the BM25 keyword-based algorithm. :param query: String to search in `Document`s' text. - :param filters: Filters applied to the retrieved `Document`s. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: Maximum number of `Document` to return. :returns: A dictionary with the following keys: - `documents`: List of `Document`s that match the query. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) docs = self._document_store._bm25_retrieval( query=query, - filters=filters or self._filters, + filters=filters, fuzziness=self._fuzziness, top_k=top_k or self._top_k, scale_score=self._scale_score, diff --git a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py index fa292fe63..10e860ea4 100644 --- a/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py +++ b/integrations/elasticsearch/src/haystack_integrations/components/retrievers/elasticsearch/embedding_retriever.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore @@ -49,6 +51,7 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, num_candidates: Optional[int] = None, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Create the ElasticsearchEmbeddingRetriever component. @@ -61,6 +64,7 @@ def __init__( Increasing this value will improve search accuracy at the cost of slower search speeds. You can read more about it in the Elasticsearch [documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#tune-approximate-knn-for-speed-accuracy) + :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of ElasticsearchDocumentStore. 
""" if not isinstance(document_store, ElasticsearchDocumentStore): @@ -71,6 +75,7 @@ def __init__( self._filters = filters or {} self._top_k = top_k self._num_candidates = num_candidates + self._filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy def to_dict(self) -> Dict[str, Any]: """ @@ -84,6 +89,7 @@ def to_dict(self) -> Dict[str, Any]: filters=self._filters, top_k=self._top_k, num_candidates=self._num_candidates, + filter_policy=self._filter_policy.value, document_store=self._document_store.to_dict(), ) @@ -100,6 +106,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchEmbeddingRetriever": data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -108,14 +118,17 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = Retrieve documents using a vector similarity metric. :param query_embedding: Embedding of the query. - :param filters: Filters applied to the retrieved `Document`s. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: Maximum number of `Document`s to return. :returns: A dictionary with the following keys: - `documents`: List of `Document`s most similar to the given `query_embedding` """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) docs = self._document_store._embedding_retrieval( query_embedding=query_embedding, - filters=filters or self._filters, + filters=filters, top_k=top_k or self._top_k, num_candidates=self._num_candidates, ) diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 12407a3dd..11016e3fc 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -63,6 +63,7 @@ def __init__( self, *, hosts: Optional[Hosts] = None, + custom_mapping: Optional[Dict[str, Any]] = None, index: str = "default", embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine", **kwargs, @@ -82,6 +83,7 @@ def __init__( [reference](https://elasticsearch-py.readthedocs.io/en/stable/api.html#module-elasticsearch) :param hosts: List of hosts running the Elasticsearch client. + :param custom_mapping: Custom mapping for the index. If not provided, a default mapping will be used. :param index: Name of index in Elasticsearch. :param embedding_similarity_function: The similarity function used to compare Documents embeddings. This parameter only takes effect if the index does not yet exist and is created. @@ -91,40 +93,60 @@ def __init__( :param **kwargs: Optional arguments that `Elasticsearch` takes. 
""" self._hosts = hosts - self._client = Elasticsearch( - hosts, - headers={"user-agent": f"haystack-py-ds/{haystack_version}"}, - **kwargs, - ) + self._client = None self._index = index self._embedding_similarity_function = embedding_similarity_function + self._custom_mapping = custom_mapping self._kwargs = kwargs - # Check client connection, this will raise if not connected - self._client.info() + if self._custom_mapping and not isinstance(self._custom_mapping, Dict): + msg = "custom_mapping must be a dictionary" + raise ValueError(msg) - # configure mapping for the embedding field - mappings = { - "properties": { - "embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function}, - "content": {"type": "text"}, - }, - "dynamic_templates": [ - { - "strings": { - "path_match": "*", - "match_mapping_type": "string", - "mapping": { - "type": "keyword", + @property + def client(self) -> Elasticsearch: + if self._client is None: + client = Elasticsearch( + self._hosts, + headers={"user-agent": f"haystack-py-ds/{haystack_version}"}, + **self._kwargs, + ) + # Check client connection, this will raise if not connected + client.info() + + if self._custom_mapping: + mappings = self._custom_mapping + else: + # Configure mapping for the embedding field if none is provided + mappings = { + "properties": { + "embedding": { + "type": "dense_vector", + "index": True, + "similarity": self._embedding_similarity_function, }, - } + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "path_match": "*", + "match_mapping_type": "string", + "mapping": { + "type": "keyword", + }, + } + } + ], } - ], - } - # Create the index if it doesn't exist - if not self._client.indices.exists(index=index): - self._client.indices.create(index=index, mappings=mappings) + # Create the index if it doesn't exist + if not client.indices.exists(index=self._index): + client.indices.create(index=self._index, mappings=mappings) + + self._client = client + + return self._client def to_dict(self) -> Dict[str, Any]: """ @@ -139,6 +161,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, hosts=self._hosts, + custom_mapping=self._custom_mapping, index=self._index, embedding_similarity_function=self._embedding_similarity_function, **self._kwargs, @@ -161,7 +184,7 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. :returns: Number of documents in the document store. 
""" - return self._client.count(index=self._index)["count"] + return self.client.count(index=self._index)["count"] def _search_documents(self, **kwargs) -> List[Document]: """ @@ -176,7 +199,7 @@ def _search_documents(self, **kwargs) -> List[Document]: from_ = 0 # Handle pagination while True: - res = self._client.search( + res = self.client.search( index=self._index, from_=from_, **kwargs, @@ -250,7 +273,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D ) documents_written, errors = helpers.bulk( - client=self._client, + client=self.client, actions=elasticsearch_actions, refresh="wait_for", index=self._index, @@ -306,7 +329,7 @@ def delete_documents(self, document_ids: List[str]) -> None: """ helpers.bulk( - client=self._client, + client=self.client, actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids), refresh="wait_for", index=self._index, diff --git a/integrations/elasticsearch/tests/test_bm25_retriever.py b/integrations/elasticsearch/tests/test_bm25_retriever.py index dd88cd0a8..3e9ebc9b8 100644 --- a/integrations/elasticsearch/tests/test_bm25_retriever.py +++ b/integrations/elasticsearch/tests/test_bm25_retriever.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchBM25Retriever from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore @@ -14,8 +16,15 @@ def test_init_default(): assert retriever._document_store == mock_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE assert not retriever._scale_score + retriever = ElasticsearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + ElasticsearchBM25Retriever(document_store=mock_store, filter_policy="keep") + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): @@ -28,6 +37,7 @@ def test_to_dict(_mock_elasticsearch_client): "document_store": { "init_parameters": { "hosts": "some fake host", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -37,12 +47,38 @@ def test_to_dict(_mock_elasticsearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": False, + "filter_policy": "replace", }, } @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict(_mock_elasticsearch_client): + data = { + "type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", + }, + "filters": {}, + "fuzziness": "AUTO", + "top_k": 10, + "scale_score": True, + "filter_policy": "replace", + }, + } + retriever = ElasticsearchBM25Retriever.from_dict(data) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._fuzziness == "AUTO" + assert retriever._top_k == 10 + assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + + 
+@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_from_dict_no_filter_policy(_mock_elasticsearch_client): data = { "type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever", "init_parameters": { @@ -62,6 +98,7 @@ def test_from_dict(_mock_elasticsearch_client): assert retriever._fuzziness == "AUTO" assert retriever._top_k == 10 assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE def test_run(): diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 308486a78..20b68f126 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -4,7 +4,7 @@ import random from typing import List -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from elasticsearch.exceptions import BadRequestError # type: ignore[import-not-found] @@ -15,6 +15,12 @@ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_init_is_lazy(_mock_es_client): + ElasticsearchDocumentStore(hosts="testhost") + _mock_es_client.assert_not_called() + + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore(hosts="some hosts") @@ -23,6 +29,7 @@ def test_to_dict(_mock_elasticsearch_client): "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", "init_parameters": { "hosts": "some hosts", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -35,6 +42,7 @@ def test_from_dict(_mock_elasticsearch_client): "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", "init_parameters": { "hosts": "some hosts", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -42,6 +50,7 @@ def test_from_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore.from_dict(data) assert document_store._hosts == "some hosts" assert document_store._index == "default" + assert document_store._custom_mapping is None assert document_store._embedding_similarity_function == "cosine" @@ -70,7 +79,7 @@ def document_store(self, request): hosts=hosts, index=index, embedding_similarity_function=embedding_similarity_function ) yield store - store._client.options(ignore_status=[400, 404]).indices.delete(index=index) + store.client.options(ignore_status=[400, 404]).indices.delete(index=index) def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ @@ -98,7 +107,7 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do super().assert_documents_are_equal(received, expected) def test_user_agent_header(self, document_store: ElasticsearchDocumentStore): - assert document_store._client._headers["user-agent"].startswith("haystack-py-ds/") + assert document_store.client._headers["user-agent"].startswith("haystack-py-ds/") def test_write_documents(self, document_store: ElasticsearchDocumentStore): docs = [Document(id="1")] @@ -280,3 +289,33 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El with 
pytest.raises(DocumentStoreError): document_store.write_documents(docs) + + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") + def test_init_with_custom_mapping(self, mock_elasticsearch): + custom_mapping = { + "properties": { + "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "path_match": "*", + "match_mapping_type": "string", + "mapping": { + "type": "keyword", + }, + } + } + ], + } + mock_client = Mock( + indices=Mock(create=Mock(), exists=Mock(return_value=False)), + ) + mock_elasticsearch.return_value = mock_client + + _ = ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping).client + mock_client.indices.create.assert_called_once_with( + index="default", + mappings=custom_mapping, + ) diff --git a/integrations/elasticsearch/tests/test_embedding_retriever.py b/integrations/elasticsearch/tests/test_embedding_retriever.py index f632c3655..2d03f0ec2 100644 --- a/integrations/elasticsearch/tests/test_embedding_retriever.py +++ b/integrations/elasticsearch/tests/test_embedding_retriever.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchEmbeddingRetriever from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore @@ -16,6 +18,12 @@ def test_init_default(): assert retriever._top_k == 10 assert retriever._num_candidates is None + retriever = ElasticsearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + ElasticsearchEmbeddingRetriever(document_store=mock_store, filter_policy="keep") + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): @@ -29,6 +37,7 @@ def test_to_dict(_mock_elasticsearch_client): "document_store": { "init_parameters": { "hosts": "some fake host", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -36,6 +45,7 @@ def test_to_dict(_mock_elasticsearch_client): }, "filters": {}, "top_k": 10, + "filter_policy": "replace", "num_candidates": None, }, } @@ -43,6 +53,29 @@ def test_to_dict(_mock_elasticsearch_client): @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_from_dict(_mock_elasticsearch_client): + t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + "num_candidates": None, + }, + } + retriever = ElasticsearchEmbeddingRetriever.from_dict(data) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._num_candidates is None + + +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_from_dict_no_filter_policy(_mock_elasticsearch_client): t = 
"haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever" data = { "type": t, @@ -61,6 +94,7 @@ def test_from_dict(_mock_elasticsearch_client): assert retriever._filters == {} assert retriever._top_k == 10 assert retriever._num_candidates is None + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE def test_run(): diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md new file mode 100644 index 000000000..9ae3da929 --- /dev/null +++ b/integrations/fastembed/CHANGELOG.md @@ -0,0 +1,63 @@ +# Changelog + +## [unreleased] + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +### Fix + +- Typo on Sparse embedders. The parameter should be "progress_bar" … (#814) + +## [integrations/fastembed-v1.1.0] - 2024-05-15 + +## [integrations/fastembed-v1.0.0] - 2024-05-06 + +## [integrations/fastembed-v0.1.0] - 2024-04-10 + +### 🚀 Features + +- *(FastEmbed)* Support for SPLADE Sparse Embedder (#579) + +### 📚 Documentation + +- Disable-class-def (#556) + +## [integrations/fastembed-v0.0.6] - 2024-03-07 + +### 📚 Documentation + +- Review and normalize docstrings - `integrations.fastembed` (#519) +- Small consistency improvements (#536) + +## [integrations/fastembed-v0.0.5] - 2024-02-20 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +## [integrations/fastembed-v0.0.4] - 2024-02-16 + +## [integrations/fastembed-v0.0.3] - 2024-02-12 + +### 🐛 Bug Fixes + +- From numpy float to float (#391) + +### 📚 Documentation + +- Update paths and titles (#397) + +## [integrations/fastembed-v0.0.2] - 2024-02-11 + +## [integrations/fastembed-v0.0.1] - 2024-02-10 + + diff --git a/integrations/fastembed/pydoc/config.yml b/integrations/fastembed/pydoc/config.yml index c8bd11762..aad50e52c 100644 --- a/integrations/fastembed/pydoc/config.yml +++ b/integrations/fastembed/pydoc/config.yml @@ -18,7 +18,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: FastEmbed integration for Haystack category_slug: integrations-api title: FastEmbed diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 538972651..9afd344c9 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -25,10 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ -"haystack-ai>=2.0.1", -"fastembed>=0.2.5", -] +dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -50,50 +45,30 @@ git_describe_command = 'git describe --tags --match="integrations/fastembed-v[0- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "ipython", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = 
"pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", - "numpy" -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -135,11 +110,18 @@ ignore = [ "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", - "FBT001", "FBT002", + "FBT001", + "FBT002", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -165,12 +147,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -178,6 +156,6 @@ module = [ "haystack_integrations.*", "fastembed.*", "pytest.*", - "numpy.*" + "numpy.*", ] ignore_missing_imports = true diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 2fc7c5ca2..66f797549 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -19,13 +19,16 @@ def get_embedding_backend( model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, + local_files_only: bool = False, ): embedding_backend_id = f"{model_name}{cache_dir}{threads}" if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances: return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] - embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads) + embedding_backend = _FastembedEmbeddingBackend( + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + ) _FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -40,8 +43,11 @@ def __init__( model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] 
= None, + local_files_only: bool = False, ): - self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads) + self.model = TextEmbedding( + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + ) def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]: # the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists @@ -66,6 +72,7 @@ def get_embedding_backend( model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, + local_files_only: bool = False, ): embedding_backend_id = f"{model_name}{cache_dir}{threads}" @@ -73,7 +80,7 @@ def get_embedding_backend( return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, cache_dir=cache_dir, threads=threads + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only ) _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -89,8 +96,11 @@ def __init__( model_name: str, cache_dir: Optional[str] = None, threads: Optional[int] = None, + local_files_only: bool = False, ): - self.model = SparseTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads) + self.model = SparseTextEmbedding( + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + ) def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: # The embed method returns a Iterable[SparseEmbedding], so we convert to Haystack SparseEmbedding type. diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index ec0b918d9..8b63582c5 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -29,14 +29,18 @@ class FastembedDocumentEmbedder: # Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa) document_list = [ Document( - content="Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint destruction. Radical species with oxidative activity, including reactive nitrogen species, represent mediators of inflammation and cartilage damage.", + content=("Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint " + "destruction. Radical species with oxidative activity, including reactive nitrogen species, " + "represent mediators of inflammation and cartilage damage."), meta={ "pubid": "25,445,628", "long_answer": "yes", }, ), Document( - content="Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion and actions are still poorly understood.", + content=("Plasma levels of pancreatic polypeptide (PP) rise upon food intake. 
Although other pancreatic " + "islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion " + "and actions are still poorly understood."), meta={ "pubid": "25,445,712", "long_answer": "yes", @@ -49,7 +53,7 @@ class FastembedDocumentEmbedder: print(f"Document Embedding: {result['documents'][0].embedding}") print(f"Embedding Dimension: {len(result['documents'][0].embedding)}") ``` - """ # noqa: E501 + """ def __init__( self, @@ -61,6 +65,7 @@ def __init__( batch_size: int = 256, progress_bar: bool = True, parallel: Optional[int] = None, + local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", ): @@ -76,11 +81,12 @@ def __init__( :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. :param batch_size: Number of strings to encode at once. - :param progress_bar: If true, displays progress bar during embedding. + :param progress_bar: If `True`, displays progress bar during embedding. :param parallel: If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. """ @@ -93,6 +99,7 @@ def __init__( self.batch_size = batch_size self.progress_bar = progress_bar self.parallel = parallel + self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator @@ -112,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]: batch_size=self.batch_size, progress_bar=self.progress_bar, parallel=self.parallel, + local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, ) @@ -122,7 +130,10 @@ def warm_up(self): """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index ed5a3208b..4b72389fa 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -12,30 +12,31 @@ class FastembedSparseDocumentEmbedder: Usage example: ```python - # To use this component, install the "fastembed-haystack" package. 
- # pip install fastembed-haystack - from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder from haystack.dataclasses import Document - doc_embedder = FastembedSparseDocumentEmbedder( + sparse_doc_embedder = FastembedSparseDocumentEmbedder( model="prithvida/Splade_PP_en_v1", batch_size=32, ) - doc_embedder.warm_up() + sparse_doc_embedder.warm_up() # Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa) document_list = [ Document( - content="Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint destruction. Radical species with oxidative activity, including reactive nitrogen species, represent mediators of inflammation and cartilage damage.", + content=("Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint " + "destruction. Radical species with oxidative activity, including reactive nitrogen species, " + "represent mediators of inflammation and cartilage damage."), meta={ "pubid": "25,445,628", "long_answer": "yes", }, ), Document( - content="Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion and actions are still poorly understood.", + content=("Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic " + "islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion " + "and actions are still poorly understood."), meta={ "pubid": "25,445,712", "long_answer": "yes", @@ -43,12 +44,12 @@ class FastembedSparseDocumentEmbedder: ), ] - result = doc_embedder.run(document_list) + result = sparse_doc_embedder.run(document_list) print(f"Document Text: {result['documents'][0].content}") - print(f"Document Embedding: {result['documents'][0].sparse_embedding}") - print(f"Embedding Dimension: {len(result['documents'][0].sparse_embedding)}") + print(f"Document Sparse Embedding: {result['documents'][0].sparse_embedding}") + print(f"Sparse Embedding Dimension: {len(result['documents'][0].sparse_embedding)}") ``` - """ # noqa: E501 + """ def __init__( self, @@ -58,6 +59,7 @@ def __init__( batch_size: int = 32, progress_bar: bool = True, parallel: Optional[int] = None, + local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", ): @@ -76,6 +78,7 @@ def __init__( If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. 
""" @@ -86,6 +89,7 @@ def __init__( self.batch_size = batch_size self.progress_bar = progress_bar self.parallel = parallel + self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator @@ -103,6 +107,7 @@ def to_dict(self) -> Dict[str, Any]: batch_size=self.batch_size, progress_bar=self.progress_bar, parallel=self.parallel, + local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, ) @@ -113,7 +118,10 @@ def warm_up(self): """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -151,7 +159,7 @@ def run(self, documents: List[Document]): embeddings = self.embedding_backend.embed( texts_to_embed, batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, ) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index b31677785..67348b2bd 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -13,30 +13,28 @@ class FastembedSparseTextEmbedder: Usage example: ```python - # To use this component, install the "fastembed-haystack" package. - # pip install fastembed-haystack - from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder - text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!" + text = ("It clearly says online this will work on a Mac OS system. " + "The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!") - text_embedder = FastembedSparseTextEmbedder( + sparse_text_embedder = FastembedSparseTextEmbedder( model="prithvida/Splade_PP_en_v1" ) - text_embedder.warm_up() + sparse_text_embedder.warm_up() - embedding = text_embedder.run(text)["embedding"] + sparse_embedding = sparse_text_embedder.run(text)["sparse_embedding"] ``` - """ # noqa: E501 + """ def __init__( self, model: str = "prithvida/Splade_PP_en_v1", cache_dir: Optional[str] = None, threads: Optional[int] = None, - batch_size: int = 32, progress_bar: bool = True, parallel: Optional[int] = None, + local_files_only: bool = False, ): """ Create a FastembedSparseTextEmbedder component. @@ -46,20 +44,20 @@ def __init__( Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. :param threads: The number of threads single onnxruntime session can use. Defaults to None. - :param batch_size: Number of strings to encode at once. - :param progress_bar: If true, displays progress bar during embedding. + :param progress_bar: If `True`, displays progress bar during embedding. 
:param parallel: If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. """ self.model_name = model self.cache_dir = cache_dir self.threads = threads - self.batch_size = batch_size self.progress_bar = progress_bar self.parallel = parallel + self.local_files_only = local_files_only def to_dict(self) -> Dict[str, Any]: """ @@ -73,9 +71,9 @@ def to_dict(self) -> Dict[str, Any]: model=self.model_name, cache_dir=self.cache_dir, threads=self.threads, - batch_size=self.batch_size, progress_bar=self.progress_bar, parallel=self.parallel, + local_files_only=self.local_files_only, ) def warm_up(self): @@ -84,7 +82,10 @@ def warm_up(self): """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) @component.output_types(sparse_embedding=SparseEmbedding) @@ -110,8 +111,7 @@ def run(self, text: str): embedding = self.embedding_backend.embed( [text], - batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, )[0] return {"sparse_embedding": embedding} diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index 9bc4475a5..a7f56ff97 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -12,12 +12,10 @@ class FastembedTextEmbedder: Usage example: ```python - # To use this component, install the "fastembed-haystack" package. - # pip install fastembed-haystack - from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder - text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!" + text = ("It clearly says online this will work on a Mac OS system. " + "The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!") text_embedder = FastembedTextEmbedder( model="BAAI/bge-small-en-v1.5" @@ -26,7 +24,7 @@ class FastembedTextEmbedder: embedding = text_embedder.run(text)["embedding"] ``` - """ # noqa: E501 + """ def __init__( self, @@ -37,6 +35,7 @@ def __init__( suffix: str = "", progress_bar: bool = True, parallel: Optional[int] = None, + local_files_only: bool = False, ): """ Create a FastembedTextEmbedder component. @@ -48,11 +47,12 @@ def __init__( :param threads: The number of threads single onnxruntime session can use. Defaults to None. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. - :param progress_bar: If true, displays progress bar during embedding. + :param progress_bar: If `True`, displays progress bar during embedding. :param parallel: If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. If 0, use all available cores. 
If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. """ self.model_name = model @@ -62,6 +62,7 @@ def __init__( self.suffix = suffix self.progress_bar = progress_bar self.parallel = parallel + self.local_files_only = local_files_only def to_dict(self) -> Dict[str, Any]: """ @@ -79,6 +80,7 @@ def to_dict(self) -> Dict[str, Any]: suffix=self.suffix, progress_bar=self.progress_bar, parallel=self.parallel, + local_files_only=self.local_files_only, ) def warm_up(self): @@ -87,7 +89,10 @@ def warm_up(self): """ if not hasattr(self, "embedding_backend"): self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend( - model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, ) @component.output_types(embedding=List[float]) diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 4dad9525d..631d9f1e0 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -27,7 +27,9 @@ def test_model_initialization(mock_instructor): _FastembedEmbeddingBackendFactory.get_embedding_backend( model_name="BAAI/bge-small-en-v1.5", ) - mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None) + mock_instructor.assert_called_once_with( + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False + ) # restore the factory state _FastembedEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index 75fdcc9c9..8afb89c69 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -22,6 +22,7 @@ def test_init_default(self): assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None + assert not embedder.local_files_only assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" @@ -38,6 +39,7 @@ def test_init_with_parameters(self): batch_size=64, progress_bar=False, parallel=1, + local_files_only=True, meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) @@ -49,6 +51,7 @@ def test_init_with_parameters(self): assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 + assert embedder.local_files_only assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " @@ -69,6 +72,7 @@ def test_to_dict(self): "batch_size": 256, "progress_bar": True, "parallel": None, + "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], }, @@ -87,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self): batch_size=64, progress_bar=False, parallel=1, + local_files_only=True, meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) @@ -102,6 +107,7 @@ def test_to_dict_with_custom_init_parameters(self): "batch_size": 64, "progress_bar": False, "parallel": 1, + "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", }, @@ -122,6 +128,7 @@ def test_from_dict(self): "batch_size": 256, 
"progress_bar": True, "parallel": None, + "local_files_only": False, "meta_fields_to_embed": [], "embedding_separator": "\n", }, @@ -135,6 +142,7 @@ def test_from_dict(self): assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None + assert not embedder.local_files_only assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" @@ -153,6 +161,7 @@ def test_from_dict_with_custom_init_parameters(self): "batch_size": 64, "progress_bar": False, "parallel": 1, + "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", }, @@ -166,6 +175,7 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 + assert embedder.local_files_only assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " @@ -180,7 +190,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False ) @patch( diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index 756eeb4b5..b4caca364 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -21,6 +21,7 @@ def test_init_default(self): assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.parallel is None + assert not embedder.local_files_only assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" @@ -35,6 +36,7 @@ def test_init_with_parameters(self): batch_size=64, progress_bar=False, parallel=1, + local_files_only=True, meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) @@ -44,6 +46,7 @@ def test_init_with_parameters(self): assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 + assert embedder.local_files_only assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " @@ -62,6 +65,7 @@ def test_to_dict(self): "batch_size": 32, "progress_bar": True, "parallel": None, + "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], }, @@ -78,6 +82,7 @@ def test_to_dict_with_custom_init_parameters(self): batch_size=64, progress_bar=False, parallel=1, + local_files_only=True, meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) @@ -91,6 +96,7 @@ def test_to_dict_with_custom_init_parameters(self): "batch_size": 64, "progress_bar": False, "parallel": 1, + "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", }, @@ -110,6 +116,7 @@ def test_from_dict(self): "batch_size": 32, "progress_bar": True, "parallel": None, + "local_files_only": False, "meta_fields_to_embed": [], "embedding_separator": "\n", }, @@ -121,6 +128,7 @@ def test_from_dict(self): assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.parallel is None + assert not embedder.local_files_only assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" @@ -138,6 +146,7 @@ def 
test_from_dict_with_custom_init_parameters(self): "batch_size": 64, "progress_bar": False, "parallel": 1, + "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", }, @@ -149,6 +158,7 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 + assert embedder.local_files_only assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " @@ -163,7 +173,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False ) @patch( @@ -260,7 +270,7 @@ def test_embed_metadata(self): "meta_value 4\ndocument-number 4", ], batch_size=32, - show_progress_bar=True, + progress_bar=True, parallel=None, ) diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 3751eea14..9e37df409 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -18,7 +18,6 @@ def test_init_default(self): assert embedder.model_name == "prithvida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None - assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.parallel is None @@ -30,14 +29,12 @@ def test_init_with_parameters(self): model="prithvida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, - batch_size=64, progress_bar=False, parallel=1, ) assert embedder.model_name == "prithvida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 @@ -53,9 +50,9 @@ def test_to_dict(self): "model": "prithvida/Splade_PP_en_v1", "cache_dir": None, "threads": None, - "batch_size": 32, "progress_bar": True, "parallel": None, + "local_files_only": False, }, } @@ -67,9 +64,9 @@ def test_to_dict_with_custom_init_parameters(self): model="prithvida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, - batch_size=64, progress_bar=False, parallel=1, + local_files_only=True, ) embedder_dict = embedder.to_dict() assert embedder_dict == { @@ -78,9 +75,9 @@ def test_to_dict_with_custom_init_parameters(self): "model": "prithvida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, - "batch_size": 64, "progress_bar": False, "parallel": 1, + "local_files_only": True, }, } @@ -94,7 +91,6 @@ def test_from_dict(self): "model": "prithvida/Splade_PP_en_v1", "cache_dir": None, "threads": None, - "batch_size": 32, "progress_bar": True, "parallel": None, }, @@ -103,7 +99,6 @@ def test_from_dict(self): assert embedder.model_name == "prithvida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None - assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.parallel is None @@ -117,7 +112,6 @@ def test_from_dict_with_custom_init_parameters(self): "model": "prithvida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, - "batch_size": 64, "progress_bar": False, "parallel": 1, }, @@ -126,7 +120,6 @@ def test_from_dict_with_custom_init_parameters(self): assert 
embedder.model_name == "prithvida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 @@ -141,7 +134,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False ) @patch( diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index 402980485..f20a98b57 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -59,6 +59,7 @@ def test_to_dict(self): "suffix": "", "progress_bar": True, "parallel": None, + "local_files_only": False, }, } @@ -74,6 +75,7 @@ def test_to_dict_with_custom_init_parameters(self): suffix="suffix", progress_bar=False, parallel=1, + local_files_only=True, ) embedder_dict = embedder.to_dict() assert embedder_dict == { @@ -86,6 +88,7 @@ def test_to_dict_with_custom_init_parameters(self): "suffix": "suffix", "progress_bar": False, "parallel": 1, + "local_files_only": True, }, } @@ -150,7 +153,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None + model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False ) @patch( diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md new file mode 100644 index 000000000..cbdd97046 --- /dev/null +++ b/integrations/google_ai/CHANGELOG.md @@ -0,0 +1,43 @@ +# Changelog + +## [unreleased] + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/google_ai-v1.1.0] - 2024-06-05 + +### 🐛 Bug Fixes + +- Handle `TypeError: Could not create Blob` in `GoogleAIGeminiChatGenerator` (#772) + +## [integrations/google_ai-v1.0.0] - 2024-03-27 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Disable-class-def (#556) + +## [integrations/google_ai-v0.2.0] - 2024-02-15 + +### Google_ai + +- Create api docs (#354) + +## [integrations/google_ai-v0.1.0] - 2024-01-25 + +### Refact + +- [**breaking**] Adjust import paths (#268) + +## [integrations/google_ai-v0.0.1] - 2024-01-03 + + diff --git a/integrations/google_ai/pydoc/config.yml b/integrations/google_ai/pydoc/config.yml index 977a91fab..c2939a812 100644 --- a/integrations/google_ai/pydoc/config.yml +++ b/integrations/google_ai/pydoc/config.yml @@ -15,7 +15,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Google AI integration for Haystack category_slug: integrations-api title: Google AI diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index d7a31fa74..db958a487 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -10,9 +10,7 @@ readme = 
"README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,10 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "google-generativeai>=0.3.1" -] +dependencies = ["haystack-ai", "google-generativeai>=0.3.1"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_ai_haystack#readme" @@ -49,47 +44,28 @@ git_describe_command = 'git describe --tags --match="integrations/google_ai-v[0- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -132,9 +108,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -159,12 +141,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -174,4 +152,4 @@ module = [ "pytest.*", "numpy.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index c55fcab67..dd065af4b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -2,9 +2,10 @@ from typing import Any, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part +from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.byte_stream import ByteStream @@ -17,10 +18,16 @@ @component class GoogleAIGeminiChatGenerator: """ - `GoogleAIGeminiChatGenerator` is a multimodal generator supporting Gemini via Google AI Studio. - It uses the `ChatMessage` dataclass to interact with the model. + Completes chats using multimodal Gemini models through Google AI Studio. + + It uses the [`ChatMessage`](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) + dataclass to interact with the model. You can use the following models: + - gemini-pro + - gemini-ultra + - gemini-pro-vision + + ### Usage example - Usage example: ```python from haystack.utils import Secret from haystack.dataclasses.chat_message import ChatMessage @@ -41,7 +48,8 @@ class GoogleAIGeminiChatGenerator: ``` - Usage example with function calling: + #### With function calling: + ```python from haystack.utils import Secret from haystack.dataclasses.chat_message import ChatMessage @@ -110,11 +118,15 @@ def __init__( * `gemini-pro-vision` * `gemini-ultra` - :param api_key: Google AI Studio API key. - :param model: Name of the model to use. - :param generation_config: The generation config to use. - Can either be a `GenerationConfig` object or a dictionary of parameters. - For the available parameters, see + :param api_key: Google AI Studio API key. To get a key, + see [Google AI Studio](https://makersuite.google.com). + :param model: Name of the model to use. Supported models are: + - gemini-pro + - gemini-ultra + - gemini-pro-vision + :param generation_config: The generation configuration to use. + This can either be a `GenerationConfig` object or a dictionary of parameters. 
+ For available parameters, see [the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig). :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. @@ -159,7 +171,14 @@ def to_dict(self) -> Dict[str, Any]: tools=self._tools, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] + data["init_parameters"]["tools"] = [] + for tool in tools: + if isinstance(tool, Tool): + # There are multiple Tool types in the Google lib, one that is a protobuf class and + # another is a simple Python class. They have a similar structure but the Python class + # can't be easily serialized to a dict. We need to convert it to a protobuf class first. + tool = tool.to_proto() # noqa: PLW2901 + data["init_parameters"]["tools"].append(ToolProto.serialize(tool)) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -179,7 +198,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] + deserialized_tools = [] + for tool in tools: + # Tools are always serialized as a protobuf class, so we need to deserialize them first + # to be able to convert them to the Python class. + proto = ToolProto.deserialize(tool) + deserialized_tools.append( + Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution) + ) + data["init_parameters"]["tools"] = deserialized_tools if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -234,12 +261,10 @@ def _message_to_content(self, message: ChatMessage) -> Content: elif message.role == ChatRole.SYSTEM: part = Part() part.text = message.content - return part elif message.role == ChatRole.FUNCTION: part = Part() part.function_response.name = message.name part.function_response.response = message.content - return part elif message.role == ChatRole.USER: part = self._convert_part(message.content) else: diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index f7b2f9097..07277e55a 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -17,9 +17,10 @@ @component class GoogleAIGeminiGenerator: """ - `GoogleAIGeminiGenerator` is a multimodal generator supporting Gemini via Google AI Studio. + Generates text using multimodal Gemini models through Google AI Studio.
+ + ### Usage example - Usage example: ```python from haystack.utils import Secret from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator @@ -30,7 +31,8 @@ class GoogleAIGeminiGenerator: print(answer) ``` - Multimodal usage example: + #### Multimodal example + ```python import requests from haystack.utils import Secret @@ -81,9 +83,9 @@ def __init__( :param api_key: Google AI Studio API key. :param model: Name of the model to use. - :param generation_config: The generation config to use. - Can either be a `GenerationConfig` object or a dictionary of parameters. - For the available parameters, see + :param generation_config: The generation configuration to use. + This can either be a `GenerationConfig` object or a dictionary of parameters. + For available parameters, see [the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig). :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 238834a84..9b3124eab 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -2,9 +2,8 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool from haystack.dataclasses.chat_message import ChatMessage from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator @@ -21,7 +20,7 @@ def test_init(monkeypatch): top_p=0.5, top_k=0.5, ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} get_current_weather_func = FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -67,9 +66,9 @@ def test_to_dict(monkeypatch): max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=2, ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} get_current_weather_func = FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -105,12 +104,12 @@ def test_to_dict(monkeypatch): "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 2, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], }, - "safety_settings": {6: 3}, + "safety_settings": {10: 3}, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -133,12 +132,12 @@ def test_from_dict(monkeypatch): "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 2, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], }, - "safety_settings": {6: 3}, + "safety_settings": {10: 3}, "tools": [ 
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -155,62 +154,34 @@ def test_from_dict(monkeypatch): max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=2, ) - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [ - Tool( - function_declarations=[ - FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - ] - ) - ] + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert len(gemini._tools) == 1 + assert len(gemini._tools[0].function_declarations) == 1 + assert gemini._tools[0].function_declarations[0].name == "get_current_weather" + assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location" + assert ( + gemini._tools[0].function_declarations[0].parameters.properties["location"].description + == "The city and state, e.g. San Francisco, CA" + ) + assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"] + assert gemini._tools[0].function_declarations[0].parameters.required == ["location"] assert isinstance(gemini._model, GenerativeModel) -@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set") +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") def test_run(): # We're ignoring the unused function argument check since we must have that argument for the test # to run successfully, but we don't actually use it. def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 return {"weather": "sunny", "temperature": 21.8, "unit": unit} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], + get_current_weather_func = FunctionDeclaration.from_function( + get_current_weather, + descriptions={ + "location": "The city and state, e.g. San Francisco, CA", + "unit": "The temperature unit of measurement, e.g. 
celsius or fahrenheit", }, ) @@ -225,3 +196,15 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 res = gemini_chat.run(messages=messages) assert len(res["replies"]) > 0 + + +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") +def test_past_conversation(): + gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") + messages = [ + ChatMessage.from_user(content="What is 2+2?"), + ChatMessage.from_system(content="It's an arithmetic operation."), + ChatMessage.from_user(content="Yeah, but what's the result?"), + ] + res = gemini_chat.run(messages=messages) + assert len(res["replies"]) > 0 diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index dea1ca11d..35c7d196b 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -20,7 +20,7 @@ def test_init(monkeypatch): top_p=0.5, top_k=0.5, ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} get_current_weather_func = FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -64,9 +64,9 @@ def test_to_dict(monkeypatch): max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=2, ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} get_current_weather_func = FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -102,12 +102,12 @@ def test_to_dict(monkeypatch): "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 2, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], }, - "safety_settings": {6: 3}, + "safety_settings": {10: 3}, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -134,7 +134,7 @@ def test_from_dict(monkeypatch): "max_output_tokens": 10, "stop_sequences": ["stop"], }, - "safety_settings": {6: 3}, + "safety_settings": {10: 3}, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -153,7 +153,7 @@ def test_from_dict(monkeypatch): top_p=0.5, top_k=0.5, ) - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert gemini._tools == [ Tool( function_declarations=[ @@ -184,7 +184,7 @@ def test_from_dict(monkeypatch): assert isinstance(gemini._model, GenerativeModel) -@pytest.mark.skipif("GOOGLE_API_KEY" not in os.environ, reason="GOOGLE_API_KEY not set") +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") def test_run(): gemini = GoogleAIGeminiGenerator(model="gemini-pro") res = gemini.run("Tell me something cool") diff --git a/integrations/google_vertex/pydoc/config.yml b/integrations/google_vertex/pydoc/config.yml index 
bee97fdb8..6e23164b9 100644 --- a/integrations/google_vertex/pydoc/config.yml +++ b/integrations/google_vertex/pydoc/config.yml @@ -20,7 +20,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Google Vertex integration for Haystack category_slug: integrations-api title: Google Vertex diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index 349bb06a0..747bbecbf 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,11 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "google-cloud-aiplatform>=1.38", - "pyarrow>3", -] +dependencies = ["haystack-ai", "google-cloud-aiplatform>=1.38", "pyarrow>3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_vertex#readme" @@ -50,47 +44,28 @@ git_describe_command = 'git describe --tags --match="integrations/google_vertex- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -130,9 +105,15 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -157,12 +138,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == 
.__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] @@ -182,4 +159,4 @@ markers = [ "embedders: embedders tests", "generators: generators tests", ] -log_cli = true \ No newline at end of file +log_cli = true diff --git a/integrations/gradient/LICENSE.txt b/integrations/gradient/LICENSE.txt deleted file mode 100644 index de4c7f39f..000000000 --- a/integrations/gradient/LICENSE.txt +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. 
We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2023 deepset GmbH - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/integrations/gradient/README.md b/integrations/gradient/README.md deleted file mode 100644 index e1a46114b..000000000 --- a/integrations/gradient/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# gradient-haystack - -[![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) - ------ - -**Table of Contents** - -- [gradient-haystack](#gradient-haystack) - - [Installation](#installation) - - [License](#license) - -## Installation - -```console -pip install gradient-haystack -``` - -## License - -`gradient-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/__init__.py b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/__init__.py deleted file mode 100644 index 7fbba1bab..000000000 --- a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from .gradient_document_embedder import GradientDocumentEmbedder -from .gradient_text_embedder import GradientTextEmbedder - -__all__ = ["GradientDocumentEmbedder", "GradientTextEmbedder"] diff --git a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py deleted file mode 100644 index a868c6c1b..000000000 --- a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_document_embedder.py +++ /dev/null @@ -1,174 +0,0 @@ -import logging -from typing import Any, Dict, List, Optional - -from gradientai import Gradient -from haystack import Document, component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace - -tqdm_imported: bool = True -try: - from tqdm import tqdm -except ImportError: - tqdm_imported = False - - -logger = logging.getLogger(__name__) - - -def _alt_progress_bar(x: Any) -> Any: - return x - - -@component -class GradientDocumentEmbedder: - """ - A component for computing Document embeddings using Gradient AI API. - - The embedding of each Document is stored in the `embedding` field of the Document. 
- - Usage example: - ```python - from haystack import Pipeline - from haystack.document_stores.in_memory import InMemoryDocumentStore - from haystack.components.writers import DocumentWriter - from haystack import Document - - from haystack_integrations.components.embedders.gradient import GradientDocumentEmbedder - - documents = [ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome."), - ] - - indexing_pipeline = Pipeline() - indexing_pipeline.add_component(instance=GradientDocumentEmbedder(), name="document_embedder") - indexing_pipeline.add_component( - instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="document_writer") - ) - indexing_pipeline.connect("document_embedder", "document_writer") - indexing_pipeline.run({"document_embedder": {"documents": documents}}) - >>> {'document_writer': {'documents_written': 3}} - ``` - """ - - def __init__( - self, - *, - model: str = "bge-large", - batch_size: int = 32_768, - access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"), # noqa: B008 - workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"), # noqa: B008 - host: Optional[str] = None, - progress_bar: bool = True, - ) -> None: - """ - Create a GradientDocumentEmbedder component. - - :param model: The name of the model to use. - :param batch_size: Update cycle for tqdm progress bar, default is to update every 32_768 docs. - :param access_token: The Gradient access token. - :param workspace_id: The Gradient workspace ID. - :param host: The Gradient host. By default, it uses [Gradient AI](https://api.gradient.ai/). - :param progress_bar: Whether to show a progress bar while embedding the documents. - """ - self._batch_size = batch_size - self._host = host - self._model_name = model - self._progress_bar = progress_bar - self._access_token = access_token - self._workspace_id = workspace_id - - self._gradient = Gradient( - access_token=access_token.resolve_value(), workspace_id=workspace_id.resolve_value(), host=host - ) - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self._model_name} - - def to_dict(self) -> dict: - """ - Serialize this component to a dictionary. - - :returns: - The serialized component as a dictionary. - """ - - return default_to_dict( - self, - model=self._model_name, - batch_size=self._batch_size, - host=self._host, - progress_bar=self._progress_bar, - access_token=self._access_token.to_dict(), - workspace_id=self._workspace_id.to_dict(), - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GradientDocumentEmbedder": - """ - Deserialize this component from a dictionary. - - :param data: The dictionary representation of this component. - :returns: - The deserialized component instance. - """ - deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"]) - return default_from_dict(cls, data) - - def warm_up(self) -> None: - """ - Initializes the component. - """ - if not hasattr(self, "_embedding_model"): - self._embedding_model = self._gradient.get_embeddings_model(slug=self._model_name) - - def _generate_embeddings(self, documents: List[Document], batch_size: int) -> List[List[float]]: - """ - Batches the documents and generates the embeddings. 
- """ - if self._progress_bar and tqdm_imported: - batches = [documents[i : i + batch_size] for i in range(0, len(documents), batch_size)] - progress_bar = tqdm - else: - # no progress bar - progress_bar = _alt_progress_bar # type: ignore - batches = [documents] - - embeddings = [] - for batch in progress_bar(batches): - response = self._embedding_model.embed(inputs=[{"input": doc.content} for doc in batch]) - embeddings.extend([e.embedding for e in response.embeddings]) - - return embeddings - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]): - """ - Embed a list of Documents. - - The embedding of each Document is stored in the `embedding` field of the Document. - - :param documents: A list of Documents to embed. - :returns: - A dictionary with the following keys: - - `documents`: The embedded Documents. - - """ - if not isinstance(documents, list) or documents and any(not isinstance(doc, Document) for doc in documents): - msg = "GradientDocumentEmbedder expects a list of Documents as input.\ - In case you want to embed a list of strings, please use the GradientTextEmbedder." - raise TypeError(msg) - - if not hasattr(self, "_embedding_model"): - msg = "The embedding model has not been loaded. Please call warm_up() before running." - raise RuntimeError(msg) - - embeddings = self._generate_embeddings(documents=documents, batch_size=self._batch_size) - for doc, embedding in zip(documents, embeddings): - doc.embedding = embedding - - return {"documents": documents} diff --git a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py b/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py deleted file mode 100644 index 3bcbb4db6..000000000 --- a/integrations/gradient/src/haystack_integrations/components/embedders/gradient/gradient_text_embedder.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Any, Dict, List, Optional - -from gradientai import Gradient -from haystack import component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace - - -@component -class GradientTextEmbedder: - """ - A component for embedding strings using models hosted on [Gradient AI](https://gradient.ai). - - Usage example: - ```python - from haystack_integrations.components.embedders.gradient import GradientTextEmbedder - from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever - from haystack.document_stores.in_memory import InMemoryDocumentStore - from haystack import Pipeline - - p = Pipeline() - p.add_component("text_embedder", GradientTextEmbedder(model="bge-large")) - p.add_component("retriever", InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())) - p.connect("text_embedder", "retriever") - p.run(data={"text_embedder": {"text":"You can embed me put I'll return no matching documents"}}) - >>> No Documents found with embeddings. Returning empty list. To generate embeddings, use a DocumentEmbedder. - >>> {'retriever': {'documents': []}} - ``` - """ - - def __init__( - self, - *, - model: str = "bge-large", - access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"), # noqa: B008 - workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"), # noqa: B008 - host: Optional[str] = None, - ) -> None: - """ - Create a GradientTextEmbedder component. - - :param model: The name of the model to use. - :param access_token: The Gradient access token. 
- :param workspace_id: The Gradient workspace ID. - :param host: The Gradient host. By default, it uses [Gradient AI](https://api.gradient.ai/). - """ - self._host = host - self._model_name = model - self._access_token = access_token - self._workspace_id = workspace_id - - self._gradient = Gradient( - host=host, access_token=access_token.resolve_value(), workspace_id=workspace_id.resolve_value() - ) - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self._model_name} - - def to_dict(self) -> dict: - """ - Serialize this component to a dictionary. - - :returns: - The serialized component as a dictionary. - """ - return default_to_dict( - self, - model=self._model_name, - host=self._host, - access_token=self._access_token.to_dict(), - workspace_id=self._workspace_id.to_dict(), - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GradientTextEmbedder": - """ - Deserialize this component from a dictionary. - - :param data: The dictionary representation of this component. - :returns: - The deserialized component instance. - """ - deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"]) - return default_from_dict(cls, data) - - def warm_up(self) -> None: - """ - Initializes the component. - """ - if not hasattr(self, "_embedding_model"): - self._embedding_model = self._gradient.get_embeddings_model(slug=self._model_name) - - @component.output_types(embedding=List[float]) - def run(self, text: str): - """Generates an embedding for a single text.""" - if not isinstance(text, str): - msg = "GradientTextEmbedder expects a string as an input.\ - In case you want to embed a list of Documents, please use the GradientDocumentEmbedder." - raise TypeError(msg) - - if not hasattr(self, "_embedding_model"): - msg = "The embedding model has not been loaded. Please call warm_up() before running." - raise RuntimeError(msg) - - result = self._embedding_model.embed(inputs=[{"input": text}]) - - if (not result) or (result.embeddings is None) or (len(result.embeddings) == 0): - msg = "The embedding model did not return any embeddings." - raise RuntimeError(msg) - - return {"embedding": result.embeddings[0].embedding} diff --git a/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py b/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py deleted file mode 100644 index 71b39d309..000000000 --- a/integrations/gradient/src/haystack_integrations/components/generators/gradient/base.py +++ /dev/null @@ -1,144 +0,0 @@ -import logging -from typing import Any, Dict, List, Optional - -from gradientai import Gradient -from haystack import component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace - -logger = logging.getLogger(__name__) - - -@component -class GradientGenerator: - """ - LLM Generator interfacing [Gradient AI](https://gradient.ai/). - - Queries the LLM using Gradient AI's SDK ('gradientai' package). - See [Gradient AI API](https://docs.gradient.ai/docs/sdk-quickstart) for more details. 
- - Usage example: - ```python - from haystack_integrations.components.generators.gradient import GradientGenerator - - llm = GradientGenerator(base_model_slug="llama2-7b-chat") - llm.warm_up() - print(llm.run(prompt="What is the meaning of life?")) - # Output: {'replies': ['42']} - ``` - """ - - def __init__( - self, - *, - access_token: Secret = Secret.from_env_var("GRADIENT_ACCESS_TOKEN"), # noqa: B008 - base_model_slug: Optional[str] = None, - host: Optional[str] = None, - max_generated_token_count: Optional[int] = None, - model_adapter_id: Optional[str] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - workspace_id: Secret = Secret.from_env_var("GRADIENT_WORKSPACE_ID"), # noqa: B008 - ) -> None: - """ - Create a GradientGenerator component. - - :param access_token: The Gradient access token as a `Secret`. If not provided it's read from the environment - variable `GRADIENT_ACCESS_TOKEN`. - :param base_model_slug: The base model slug to use. - :param host: The Gradient host. By default, it uses [Gradient AI](https://api.gradient.ai/). - :param max_generated_token_count: The maximum number of tokens to generate. - :param model_adapter_id: The model adapter ID to use. - :param temperature: The temperature to use. - :param top_k: The top k to use. - :param top_p: The top p to use. - :param workspace_id: The Gradient workspace ID as a `Secret`. If not provided it's read from the environment - variable `GRADIENT_WORKSPACE_ID`. - """ - self._access_token = access_token - self._base_model_slug = base_model_slug - self._host = host - self._max_generated_token_count = max_generated_token_count - self._model_adapter_id = model_adapter_id - self._temperature = temperature - self._top_k = top_k - self._top_p = top_p - self._workspace_id = workspace_id - - has_base_model_slug = base_model_slug is not None and base_model_slug != "" - has_model_adapter_id = model_adapter_id is not None and model_adapter_id != "" - - if not has_base_model_slug and not has_model_adapter_id: - msg = "Either base_model_slug or model_adapter_id must be provided." - raise ValueError(msg) - if has_base_model_slug and has_model_adapter_id: - msg = "Only one of base_model_slug or model_adapter_id must be provided." - raise ValueError(msg) - - if has_base_model_slug: - self._base_model_slug = base_model_slug - if has_model_adapter_id: - self._model_adapter_id = model_adapter_id - - self._gradient = Gradient( - access_token=access_token.resolve_value(), host=host, workspace_id=workspace_id.resolve_value() - ) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - - :returns: - The serialized component as a dictionary. - """ - return default_to_dict( - self, - access_token=self._access_token.to_dict(), - base_model_slug=self._base_model_slug, - host=self._host, - max_generated_token_count=self._max_generated_token_count, - model_adapter_id=self._model_adapter_id, - temperature=self._temperature, - top_k=self._top_k, - top_p=self._top_p, - workspace_id=self._workspace_id.to_dict(), - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GradientGenerator": - """ - Deserialize this component from a dictionary. - - :param data: The dictionary representation of this component. - :returns: - The deserialized component instance. 
- """ - - deserialize_secrets_inplace(data["init_parameters"], keys=["access_token", "workspace_id"]) - return default_from_dict(cls, data) - - def warm_up(self): - """ - Initializes the LLM model instance if it doesn't exist. - """ - if not hasattr(self, "_model"): - if isinstance(self._base_model_slug, str): - self._model = self._gradient.get_base_model(base_model_slug=self._base_model_slug) - if isinstance(self._model_adapter_id, str): - self._model = self._gradient.get_model_adapter(model_adapter_id=self._model_adapter_id) - - @component.output_types(replies=List[str]) - def run(self, prompt: str): - """ - Queries the LLM with the prompt to produce replies. - - :param prompt: The prompt to be sent to the generative model. - """ - resp = self._model.complete( - query=prompt, - max_generated_token_count=self._max_generated_token_count, - temperature=self._temperature, - top_k=self._top_k, - top_p=self._top_p, - ) - return {"replies": [resp.generated_output]} diff --git a/integrations/gradient/tests/__init__.py b/integrations/gradient/tests/__init__.py deleted file mode 100644 index e873bc332..000000000 --- a/integrations/gradient/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/gradient/tests/test_gradient_document_embedder.py b/integrations/gradient/tests/test_gradient_document_embedder.py deleted file mode 100644 index 3bc739f3e..000000000 --- a/integrations/gradient/tests/test_gradient_document_embedder.py +++ /dev/null @@ -1,162 +0,0 @@ -from unittest.mock import MagicMock, NonCallableMagicMock - -import numpy as np -import pytest -from gradientai.openapi.client.models.generate_embedding_success import GenerateEmbeddingSuccess -from haystack import Document -from haystack.utils import Secret - -from haystack_integrations.components.embedders.gradient import GradientDocumentEmbedder - -access_token = "access_token" -workspace_id = "workspace_id" -model = "bge-large" - - -@pytest.fixture -def tokens_from_env(monkeypatch): - monkeypatch.setenv("GRADIENT_ACCESS_TOKEN", access_token) - monkeypatch.setenv("GRADIENT_WORKSPACE_ID", workspace_id) - - -class TestGradientDocumentEmbedder: - def test_init_from_env(self, tokens_from_env): - - embedder = GradientDocumentEmbedder() - assert embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_init_without_access_token(self, monkeypatch): - monkeypatch.delenv("GRADIENT_ACCESS_TOKEN", raising=False) - - with pytest.raises(ValueError): - GradientDocumentEmbedder() - - def test_init_without_workspace(self, monkeypatch): - monkeypatch.delenv("GRADIENT_WORKSPACE_ID", raising=False) - - with pytest.raises(ValueError): - GradientDocumentEmbedder() - - def test_init_from_params(self): - embedder = GradientDocumentEmbedder( - access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) - ) - assert embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_init_from_params_precedence(self, monkeypatch): - monkeypatch.setenv("GRADIENT_ACCESS_TOKEN", "env_access_token") - monkeypatch.setenv("GRADIENT_WORKSPACE_ID", "env_workspace_id") - - embedder = GradientDocumentEmbedder( - access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) - ) - assert 
embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_to_dict(self, tokens_from_env): - component = GradientDocumentEmbedder() - data = component.to_dict() - t = "haystack_integrations.components.embedders.gradient.gradient_document_embedder.GradientDocumentEmbedder" - assert data == { - "type": t, - "init_parameters": { - "access_token": {"env_vars": ["GRADIENT_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, - "batch_size": 32768, - "host": None, - "model": "bge-large", - "progress_bar": True, - "workspace_id": {"env_vars": ["GRADIENT_WORKSPACE_ID"], "strict": True, "type": "env_var"}, - }, - } - - def test_warmup(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - embedder._gradient.get_embeddings_model = MagicMock() - embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") - - def test_warmup_doesnt_reload(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - embedder._gradient.get_embeddings_model = MagicMock(default_return_value="fake model") - embedder.warm_up() - embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") - - def test_run_fail_if_not_warmed_up(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - - with pytest.raises(RuntimeError, match="warm_up()"): - embedder.run(documents=[Document(content=f"document number {i}") for i in range(5)]) - - def test_run(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - embedder._embedding_model = NonCallableMagicMock() - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(5)] - ) - - documents = [Document(content=f"document number {i}") for i in range(5)] - - result = embedder.run(documents=documents) - - assert embedder._embedding_model.embed.call_count == 1 - assert isinstance(result["documents"], list) - assert len(result["documents"]) == len(documents) - for doc in result["documents"]: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert isinstance(doc.embedding[0], float) - - def test_run_batch(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - embedder._embedding_model = NonCallableMagicMock() - - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(110)] - ) - - documents = [Document(content=f"document number {i}") for i in range(110)] - - result = embedder.run(documents=documents) - - assert embedder._embedding_model.embed.call_count == 1 - assert isinstance(result["documents"], list) - assert len(result["documents"]) == len(documents) - for doc in result["documents"]: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert isinstance(doc.embedding[0], float) - - def test_run_custom_batch(self, tokens_from_env): - embedder = GradientDocumentEmbedder(batch_size=20) - embedder._embedding_model = NonCallableMagicMock() - - document_count = 101 - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": i} for i in range(document_count)] - ) - - documents = [Document(content=f"document number {i}") for i in range(document_count)] - - result = embedder.run(documents=documents) - - assert 
embedder._embedding_model.embed.call_count == 6 - assert isinstance(result["documents"], list) - assert len(result["documents"]) == len(documents) - for doc in result["documents"]: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert isinstance(doc.embedding[0], float) - - def test_run_empty(self, tokens_from_env): - embedder = GradientDocumentEmbedder() - embedder._embedding_model = NonCallableMagicMock() - - result = embedder.run(documents=[]) - - assert result["documents"] == [] diff --git a/integrations/gradient/tests/test_gradient_rag_pipelines.py b/integrations/gradient/tests/test_gradient_rag_pipelines.py deleted file mode 100644 index 89ec7cfb2..000000000 --- a/integrations/gradient/tests/test_gradient_rag_pipelines.py +++ /dev/null @@ -1,90 +0,0 @@ -import json -import os - -import pytest -from haystack import Document, Pipeline -from haystack.components.builders.answer_builder import AnswerBuilder -from haystack.components.builders.prompt_builder import PromptBuilder -from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever -from haystack.components.writers import DocumentWriter -from haystack.document_stores.in_memory import InMemoryDocumentStore - -from haystack_integrations.components.embedders.gradient import GradientDocumentEmbedder, GradientTextEmbedder -from haystack_integrations.components.generators.gradient import GradientGenerator - - -@pytest.mark.integration -@pytest.mark.skipif( - not os.environ.get("GRADIENT_ACCESS_TOKEN", None) or not os.environ.get("GRADIENT_WORKSPACE_ID", None), - reason="Export env variables called GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID \ - containing the Gradient configuration settings to run this test.", -) -def test_gradient_embedding_retrieval_rag_pipeline(tmp_path): - # Create the RAG pipeline - prompt_template = """ - Given these documents, answer the question.\nDocuments: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - \nQuestion: {{question}} - \nAnswer: - """ - - rag_pipeline = Pipeline() - embedder = GradientTextEmbedder() - rag_pipeline.add_component(instance=embedder, name="text_embedder") - rag_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever" - ) - rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") - rag_pipeline.add_component(instance=GradientGenerator(base_model_slug="llama2-7b-chat"), name="llm") - rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") - rag_pipeline.connect("text_embedder", "retriever") - rag_pipeline.connect("retriever", "prompt_builder.documents") - rag_pipeline.connect("prompt_builder", "llm") - rag_pipeline.connect("llm.replies", "answer_builder.replies") - rag_pipeline.connect("retriever", "answer_builder.documents") - - # Draw the pipeline - rag_pipeline.draw(tmp_path / "test_gradient_embedding_rag_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f: - json.dump(rag_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_bm25_rag_pipeline.json") as f: - rag_pipeline = Pipeline.from_dict(json.load(f)) - - # Populate the document store - documents = [ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome."), - ] - document_store = 
rag_pipeline.get_component("retriever").document_store - indexing_pipeline = Pipeline() - indexing_pipeline.add_component(instance=GradientDocumentEmbedder(), name="document_embedder") - indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer") - indexing_pipeline.connect("document_embedder", "document_writer") - indexing_pipeline.run({"document_embedder": {"documents": documents}}) - - # Query and assert - questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"] - answers_spywords = ["Jean", "Mark", "Giorgio"] - - for question, spyword in zip(questions, answers_spywords): - result = rag_pipeline.run( - { - "text_embedder": {"text": question}, - "prompt_builder": {"question": question}, - "answer_builder": {"query": question}, - } - ) - - assert len(result["answer_builder"]["answers"]) == 1 - generated_answer = result["answer_builder"]["answers"][0] - assert spyword in generated_answer.data - assert generated_answer.query == question - assert hasattr(generated_answer, "documents") - assert hasattr(generated_answer, "meta") diff --git a/integrations/gradient/tests/test_gradient_text_embedder.py b/integrations/gradient/tests/test_gradient_text_embedder.py deleted file mode 100644 index b12587994..000000000 --- a/integrations/gradient/tests/test_gradient_text_embedder.py +++ /dev/null @@ -1,124 +0,0 @@ -from unittest.mock import MagicMock, NonCallableMagicMock - -import numpy as np -import pytest -from gradientai.openapi.client.models.generate_embedding_success import GenerateEmbeddingSuccess -from haystack.utils import Secret - -from haystack_integrations.components.embedders.gradient import GradientTextEmbedder - -access_token = "access_token" -workspace_id = "workspace_id" -model = "bge-large" - - -@pytest.fixture -def tokens_from_env(monkeypatch): - monkeypatch.setenv("GRADIENT_ACCESS_TOKEN", access_token) - monkeypatch.setenv("GRADIENT_WORKSPACE_ID", workspace_id) - - -class TestGradientTextEmbedder: - def test_init_from_env(self, tokens_from_env): - embedder = GradientTextEmbedder() - assert embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_init_without_access_token(self, monkeypatch): - monkeypatch.delenv("GRADIENT_ACCESS_TOKEN", raising=False) - - with pytest.raises(ValueError): - GradientTextEmbedder() - - def test_init_without_workspace(self, monkeypatch): - monkeypatch.delenv("GRADIENT_WORKSPACE_ID", raising=False) - - with pytest.raises(ValueError): - GradientTextEmbedder() - - def test_init_from_params(self): - embedder = GradientTextEmbedder( - access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) - ) - assert embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_init_from_params_precedence(self, monkeypatch): - monkeypatch.setenv("GRADIENT_ACCESS_TOKEN", "env_access_token") - monkeypatch.setenv("GRADIENT_WORKSPACE_ID", "env_workspace_id") - - embedder = GradientTextEmbedder( - access_token=Secret.from_token(access_token), workspace_id=Secret.from_token(workspace_id) - ) - assert embedder is not None - assert embedder._gradient.workspace_id == workspace_id - assert embedder._gradient._api_client.configuration.access_token == access_token - - def test_to_dict(self, tokens_from_env): - component = GradientTextEmbedder() - data 
= component.to_dict() - assert data == { - "type": "haystack_integrations.components.embedders.gradient.gradient_text_embedder.GradientTextEmbedder", - "init_parameters": { - "access_token": {"env_vars": ["GRADIENT_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, - "host": None, - "model": "bge-large", - "workspace_id": {"env_vars": ["GRADIENT_WORKSPACE_ID"], "strict": True, "type": "env_var"}, - }, - } - - def test_warmup(self, tokens_from_env): - embedder = GradientTextEmbedder() - embedder._gradient.get_embeddings_model = MagicMock() - embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") - - def test_warmup_doesnt_reload(self, tokens_from_env): - embedder = GradientTextEmbedder() - embedder._gradient.get_embeddings_model = MagicMock(default_return_value="fake model") - embedder.warm_up() - embedder.warm_up() - embedder._gradient.get_embeddings_model.assert_called_once_with(slug="bge-large") - - def test_run_fail_if_not_warmed_up(self, tokens_from_env): - embedder = GradientTextEmbedder() - - with pytest.raises(RuntimeError, match="warm_up()"): - embedder.run(text="The food was delicious") - - def test_run_fail_when_no_embeddings_returned(self, tokens_from_env): - embedder = GradientTextEmbedder() - embedder._embedding_model = NonCallableMagicMock() - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess(embeddings=[]) - - with pytest.raises(RuntimeError): - _result = embedder.run(text="The food was delicious") - embedder._embedding_model.embed.assert_called_once_with(inputs=[{"input": "The food was delicious"}]) - - def test_run_empty_string(self, tokens_from_env): - embedder = GradientTextEmbedder() - embedder._embedding_model = NonCallableMagicMock() - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": 0}] - ) - - result = embedder.run(text="") - embedder._embedding_model.embed.assert_called_once_with(inputs=[{"input": ""}]) - - assert len(result["embedding"]) == 1024 # 1024 is the bge-large embedding size - assert all(isinstance(x, float) for x in result["embedding"]) - - def test_run(self, tokens_from_env): - embedder = GradientTextEmbedder() - embedder._embedding_model = NonCallableMagicMock() - embedder._embedding_model.embed.return_value = GenerateEmbeddingSuccess( - embeddings=[{"embedding": np.random.rand(1024).tolist(), "index": 0}] - ) - - result = embedder.run(text="The food was delicious") - embedder._embedding_model.embed.assert_called_once_with(inputs=[{"input": "The food was delicious"}]) - - assert len(result["embedding"]) == 1024 # 1024 is the bge-large embedding size - assert all(isinstance(x, float) for x in result["embedding"]) diff --git a/integrations/instructor_embedders/pydoc/config.yml b/integrations/instructor_embedders/pydoc/config.yml index dd5e38faa..a9ccc243c 100644 --- a/integrations/instructor_embedders/pydoc/config.yml +++ b/integrations/instructor_embedders/pydoc/config.yml @@ -16,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Instructor embedders integration for Haystack category_slug: integrations-api title: Instructor Embedders diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index faf5d0216..0543a9a88 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ 
b/integrations/instructor_embedders/pyproject.toml @@ -67,22 +67,17 @@ git_describe_command = 'git describe --tags --match="integrations/instructor_emb dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.test.matrix]] python = ["38", "39", "310", "311"] @@ -93,7 +88,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -104,7 +99,7 @@ parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true +show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.ruff] diff --git a/integrations/jina/pydoc/config.yml b/integrations/jina/pydoc/config.yml index 67788d26d..8c7a241f6 100644 --- a/integrations/jina/pydoc/config.yml +++ b/integrations/jina/pydoc/config.yml @@ -1,7 +1,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../src] - modules: + modules: [ "haystack_integrations.components.embedders.jina.document_embedder", "haystack_integrations.components.embedders.jina.text_embedder", @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Jina integration for Haystack category_slug: integrations-api title: Jina diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index e724cac96..fa2fd50ed 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -43,12 +43,14 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/jina-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] @@ -60,7 +62,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -135,12 +137,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = ["haystack.*", "haystack_integrations.*", "pytest.*"] diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md new file mode 100644 index 000000000..2efa17a68 --- /dev/null +++ b/integrations/langfuse/CHANGELOG.md @@ -0,0 +1,19 @@ +# Changelog + +## [integrations/langfuse-v0.2.0] - 2024-06-18 + +## [integrations/langfuse-v0.1.0] - 2024-06-13 + +### 🚀 Features + +- Langfuse integration (#686) + +### 🐛 Bug Fixes + +- Performance optimizations and value error when streaming in langfuse (#798) + +### ⚙️ Miscellaneous Tasks + +- Use ChatMessage to_openai_format, update unit tests, pydocs (#725) + + diff --git a/integrations/langfuse/LICENSE.txt b/integrations/langfuse/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/langfuse/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
+ +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/langfuse/README.md b/integrations/langfuse/README.md new file mode 100644 index 000000000..c39f9f864 --- /dev/null +++ b/integrations/langfuse/README.md @@ -0,0 +1,117 @@ +# langfuse-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/langfuse-haystack.svg)](https://pypi.org/project/langfuse-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/langfuse-haystack.svg)](https://pypi.org/project/langfuse-haystack) + +langfuse-haystack integrates tracing capabilities into [Haystack](https://github.com/deepset-ai/haystack) (2.x) pipelines using [Langfuse](https://langfuse.com/). This package enhances the visibility of pipeline runs by capturing comprehensive details of the execution traces, including API calls, context data, prompts, and more. Whether you're monitoring model performance, pinpointing areas for improvement, or creating datasets for fine-tuning and testing from your pipeline executions, langfuse-haystack is the right tool for you. 
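+
+For example, once a pipeline run has produced a trace URL (see the Usage section below), the trace can also be fetched back programmatically from the Langfuse public API. The snippet below is a minimal sketch that mirrors this package's integration tests; it assumes `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` are exported and that you substitute a trace URL you own:
+
+```python
+import os
+from urllib.parse import urlparse
+
+import requests
+from requests.auth import HTTPBasicAuth
+
+# Placeholder: use the value returned by a pipeline run, e.g. response["tracer"]["trace_url"]
+trace_url = "https://cloud.langfuse.com/trace/your-trace-id"
+
+# The trace id is the last path segment of the trace URL
+trace_id = os.path.basename(urlparse(trace_url).path)
+
+# GET request with Basic Authentication: public key as user, secret key as password
+resp = requests.get(
+    "https://cloud.langfuse.com/api/public/traces/" + trace_id,
+    auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]),
+)
+resp.raise_for_status()
+print(resp.json())
+```
+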
+## Features + +- Easy integration with Haystack pipelines +- Capture the full context of the execution +- Track model usage and cost +- Collect user feedback +- Identify low-quality outputs +- Build fine-tuning and testing datasets + +## Installation + +To install langfuse-haystack, run the following command: + +```sh +pip install langfuse-haystack +``` + +## Usage + +To enable tracing in your Haystack pipeline, add the `LangfuseConnector` to your pipeline. +You also need to set the `LANGFUSE_SECRET_KEY` and `LANGFUSE_PUBLIC_KEY` environment variables in order to connect to Langfuse account. +You can get these keys by signing up for an account on the Langfuse website. + +⚠️ **Important:** To ensure proper tracing, always set environment variables before importing any Haystack components. This is crucial because Haystack initializes its internal tracing components during import. + +Here's the correct way to set up your script: + +```python +import os + +# Set environment variables first +os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com" +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +# Then import Haystack components +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack import Pipeline + +from haystack_integrations.components.connectors.langfuse import LangfuseConnector + +# Rest of your code... +``` + +Alternatively, an even better practice is to set these environment variables in your shell before running the script. + + +Here's a full example: + +```python +import os + +os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com" +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack import Pipeline + +from haystack_integrations.components.connectors.langfuse import LangfuseConnector + +if __name__ == "__main__": + pipe = Pipeline() + pipe.add_component("tracer", LangfuseConnector("Chat example")) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) + + pipe.connect("prompt_builder.prompt", "llm.messages") + + messages = [ + ChatMessage.from_system("Always respond in German even if some input data is in other languages."), + ChatMessage.from_user("Tell me about {{location}}"), + ] + + response = pipe.run( + data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}} + ) + print(response["llm"]["replies"][0]) + print(response["tracer"]["trace_url"]) +``` + +In this example, we add the `LangfuseConnector` to the pipeline with the name "tracer". Each run of the pipeline produces one trace viewable on the Langfuse website with a specific URL. The trace captures the entire execution context, including the prompts, completions, and metadata. + +## Trace Visualization + +Langfuse provides a user-friendly interface to visualize and analyze the traces generated by your Haystack pipeline. Login into your Langfuse account and navigate to the trace URL to view the trace details. + +## Contributing + +`hatch` is the best way to interact with this project. 
To install it, run: +```sh +pip install hatch +``` + +With `hatch` installed, run all the tests: +``` +hatch run test +``` + +Run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + +## License + +`langfuse-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/langfuse/example/basic_rag.py b/integrations/langfuse/example/basic_rag.py new file mode 100644 index 000000000..492a14d49 --- /dev/null +++ b/integrations/langfuse/example/basic_rag.py @@ -0,0 +1,65 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +from datasets import load_dataset +from haystack import Document, Pipeline +from haystack.components.builders import PromptBuilder +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.generators import OpenAIGenerator +from haystack.components.retrievers import InMemoryEmbeddingRetriever +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.connectors.langfuse import LangfuseConnector + + +def get_pipeline(document_store: InMemoryDocumentStore): + retriever = InMemoryEmbeddingRetriever(document_store=document_store, top_k=2) + + template = """ + Given the following information, answer the question. + + Context: + {% for document in documents %} + {{ document.content }} + {% endfor %} + + Question: {{question}} + Answer: + """ + + prompt_builder = PromptBuilder(template=template) + + basic_rag_pipeline = Pipeline() + # Add components to your pipeline + basic_rag_pipeline.add_component("tracer", LangfuseConnector("Basic RAG Pipeline")) + basic_rag_pipeline.add_component( + "text_embedder", SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") + ) + basic_rag_pipeline.add_component("retriever", retriever) + basic_rag_pipeline.add_component("prompt_builder", prompt_builder) + basic_rag_pipeline.add_component("llm", OpenAIGenerator(model="gpt-3.5-turbo", generation_kwargs={"n": 2})) + + # Now, connect the components to each other + # NOTE: the tracer component doesn't need to be connected to anything in order to work + basic_rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + basic_rag_pipeline.connect("retriever", "prompt_builder.documents") + basic_rag_pipeline.connect("prompt_builder", "llm") + + return basic_rag_pipeline + + +if __name__ == "__main__": + document_store = InMemoryDocumentStore() + dataset = load_dataset("bilgeyucel/seven-wonders", split="train") + embedder = SentenceTransformersDocumentEmbedder("sentence-transformers/all-MiniLM-L6-v2") + embedder.warm_up() + docs_with_embeddings = embedder.run([Document(**ds) for ds in dataset]).get("documents") or [] # type: ignore + document_store.write_documents(docs_with_embeddings) + + pipeline = get_pipeline(document_store) + question = "What does Rhodes Statue look like?" 
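+    # The same question goes to both the text embedder (to retrieve matching documents)
+    # and the prompt builder; the tracer component needs no inputs and returns the URL of the captured trace.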
+ response = pipeline.run({"text_embedder": {"text": question}, "prompt_builder": {"question": question}}) + + print(response["llm"]["replies"][0]) + print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py new file mode 100644 index 000000000..443d65a13 --- /dev/null +++ b/integrations/langfuse/example/chat.py @@ -0,0 +1,27 @@ +import os + +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack_integrations.components.connectors.langfuse import LangfuseConnector + +if __name__ == "__main__": + + pipe = Pipeline() + pipe.add_component("tracer", LangfuseConnector("Chat example")) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) + + pipe.connect("prompt_builder.prompt", "llm.messages") + + messages = [ + ChatMessage.from_system("Always respond in German even if some input data is in other languages."), + ChatMessage.from_user("Tell me about {{location}}"), + ] + + response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + print(response["llm"]["replies"][0]) + print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/example/requirements.txt b/integrations/langfuse/example/requirements.txt new file mode 100644 index 000000000..3db2429f2 --- /dev/null +++ b/integrations/langfuse/example/requirements.txt @@ -0,0 +1,3 @@ +langfuse-haystack +datasets +sentence-transformers \ No newline at end of file diff --git a/integrations/gradient/pydoc/config.yml b/integrations/langfuse/pydoc/config.yml similarity index 55% rename from integrations/gradient/pydoc/config.yml rename to integrations/langfuse/pydoc/config.yml index a0ec5f72d..c08bb35c3 100644 --- a/integrations/gradient/pydoc/config.yml +++ b/integrations/langfuse/pydoc/config.yml @@ -2,9 +2,8 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../src] modules: [ - "haystack_integrations.components.embedders.gradient.gradient_document_embedder", - "haystack_integrations.components.embedders.gradient.gradient_text_embedder", - "haystack_integrations.components.generators.gradient.base", + "haystack_integrations.components.connectors.langfuse.langfuse_connector", + "haystack_integrations.tracing.langfuse.tracer", ] ignore_when_discovered: ["__init__"] processors: @@ -16,16 +15,16 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer - excerpt: Cohere integration for Haystack + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Langfuse integration for Haystack category_slug: integrations-api - title: Gradient - slug: integrations-gradient - order: 110 + title: langfuse + slug: integrations-langfuse + order: 136 markdown: descriptive_class_title: false classdef_code_block: false descriptive_module_title: true add_method_class_prefix: true add_member_class_prefix: false - filename: _readme_gradient.md \ No newline at end of file + filename: _readme_langfuse.md diff --git a/integrations/gradient/pyproject.toml b/integrations/langfuse/pyproject.toml similarity index 57% rename from integrations/gradient/pyproject.toml rename to integrations/langfuse/pyproject.toml index 0ed7c66d2..cf7b85b64 
100644 --- a/integrations/gradient/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -3,96 +3,76 @@ requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] -name = "gradient-haystack" +name = "langfuse-haystack" dynamic = ["version"] description = '' readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "Mateusz Haligowski", email = "contact@gradient.ai" }, - { name = "Michael Feil", email = "contact@gradient.ai" }, - { name = "Hayden Wilson", email = "contact@gradient.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ - "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "gradientai>=1.4.0", -] -optional-dependencies = { tqdm = ["tqdm"] } +dependencies = ["haystack-ai>=2.1.0", "langfuse"] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient#readme" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/langfuse#readme" Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" -Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/langfuse" [tool.hatch.build.targets.wheel] packages = ["src/haystack_integrations"] [tool.hatch.version] source = "vcs" -tag-pattern = 'integrations\/gradient-v(?P.*)' +tag-pattern = 'integrations\/langfuse-v(?P.*)' + [tool.hatch.version.raw-options] root = "../.." 
-git_describe_command = 'git describe --tags --match="integrations/gradient-v[0-9]*"' +git_describe_command = 'git describe --tags --match="integrations/langfuse-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + + [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] + [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true [tool.black] target-version = ["py38"] @@ -102,7 +82,7 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -select = [ +lint.select = [ "A", "ARG", "B", @@ -111,7 +91,6 @@ select = [ "E", "EM", "F", - "FBT", "I", "ICN", "ISC", @@ -129,30 +108,41 @@ select = [ "W", "YTT", ] -ignore = [ + +lint.ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", - # Allow boolean positional values in function calls, like `dict.get(... 
True)` - "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", -] -unfixable = [ + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Asserts + "S101", +] +lint.unfixable = [ # Don't touch unused imports "F401", ] +extend-exclude = ["tests", "example"] -[tool.ruff.isort] -known-first-party = ["haystack_integrations"] +[tool.ruff.lint.isort] +known-first-party = ["src"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] -# Tests can use magic values, assertions, relative imports, and unused fixtures -"tests/**/*" = ["PLR2004", "S101", "TID252", "ARG002"] +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] +# Examples can print their output +"examples/**" = ["T201"] +"tests/**" = ["T201"] [tool.coverage.run] source = ["haystack_integrations"] @@ -162,20 +152,20 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] - +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ - "gradientai.*", + "langfuse.*", "haystack.*", "haystack_integrations.*", "pytest.*", "numpy.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = ["integration: integration tests"] +log_cli = true diff --git a/integrations/gradient/src/haystack_integrations/components/generators/gradient/__init__.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/__init__.py similarity index 57% rename from integrations/gradient/src/haystack_integrations/components/generators/gradient/__init__.py rename to integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/__init__.py index a9d7cd421..c17a196ce 100644 --- a/integrations/gradient/src/haystack_integrations/components/generators/gradient/__init__.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from .base import GradientGenerator +from .langfuse_connector import LangfuseConnector -__all__ = ["GradientGenerator"] +__all__ = ["LangfuseConnector"] diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py new file mode 100644 index 000000000..51703823e --- /dev/null +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -0,0 +1,116 @@ +from haystack import component, tracing +from haystack_integrations.tracing.langfuse import LangfuseTracer + +from langfuse import Langfuse + + +@component +class LangfuseConnector: + """ + LangfuseConnector connects Haystack LLM framework with Langfuse in order to enable the tracing of operations + and data flow within various components of a pipeline. + + Simply add this component to your pipeline, but *do not* connect it to any other component. 
The LangfuseConnector + will automatically trace the operations and data flow within the pipeline. + + Note that you need to set the `LANGFUSE_SECRET_KEY` and `LANGFUSE_PUBLIC_KEY` environment variables in order + to use this component. The `LANGFUSE_SECRET_KEY` and `LANGFUSE_PUBLIC_KEY` are the secret and public keys provided + by Langfuse. You can get these keys by signing up for an account on the Langfuse website. + + In addition, you need to set the `HAYSTACK_CONTENT_TRACING_ENABLED` environment variable to `true` in order to + enable Haystack tracing in your pipeline. + + Lastly, you may disable flushing the data after each component by setting the `HAYSTACK_LANGFUSE_ENFORCE_FLUSH` + environent variable to `false`. By default, the data is flushed after each component and blocks the thread until + the data is sent to Langfuse. **Caution**: Disabling this feature may result in data loss if the program crashes + before the data is sent to Langfuse. Make sure you will call langfuse.flush() explicitly before the program exits. + E.g. by using tracer.actual_tracer.flush(): + + ```python + from haystack.tracing import tracer + + try: + # your code here + finally: + tracer.actual_tracer.flush() + ``` + or in FastAPI by defining a shutdown event handler: + ```python + from haystack.tracing import tracer + + # ... + + + @app.on_event("shutdown") + async def shutdown_event(): + tracer.actual_tracer.flush() + ``` + + Here is an example of how to use it: + + ```python + import os + + os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + + from haystack import Pipeline + from haystack.components.builders import ChatPromptBuilder + from haystack.components.generators.chat import OpenAIChatGenerator + from haystack.dataclasses import ChatMessage + from haystack_integrations.components.connectors.langfuse import ( + LangfuseConnector, + ) + + if __name__ == "__main__": + pipe = Pipeline() + pipe.add_component("tracer", LangfuseConnector("Chat example")) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) + + pipe.connect("prompt_builder.prompt", "llm.messages") + + messages = [ + ChatMessage.from_system( + "Always respond in German even if some input data is in other languages." + ), + ChatMessage.from_user("Tell me about {{location}}"), + ] + + response = pipe.run( + data={ + "prompt_builder": { + "template_variables": {"location": "Berlin"}, + "template": messages, + } + } + ) + print(response["llm"]["replies"][0]) + print(response["tracer"]["trace_url"]) + ``` + + """ + + def __init__(self, name: str, public: bool = False): + """ + Initialize the LangfuseConnector component. + + :param name: The name of the pipeline or component. This name will be used to identify the tracing run on the + Langfuse dashboard. + :param public: Whether the tracing data should be public or private. If set to `True`, the tracing data will be + publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private and + only accessible to the Langfuse account owner. The default is `False`. + """ + self.name = name + self.tracer = LangfuseTracer(tracer=Langfuse(), name=name, public=public) + tracing.enable_tracing(self.tracer) + + @component.output_types(name=str, trace_url=str) + def run(self): + """ + Runs the LangfuseConnector component. + + :returns: A dictionary with the following keys: + - `name`: The name of the tracing component. + - `trace_url`: The URL to the tracing data. 
+ """ + return {"name": self.name, "trace_url": self.tracer.get_trace_url()} diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py new file mode 100644 index 000000000..e7331852d --- /dev/null +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .tracer import LangfuseTracer + +__all__ = ["LangfuseTracer"] diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py new file mode 100644 index 000000000..4bf0da2f8 --- /dev/null +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -0,0 +1,174 @@ +import contextlib +import os +from typing import Any, Dict, Iterator, Optional, Union + +from haystack.dataclasses import ChatMessage +from haystack.tracing import Span, Tracer, tracer +from haystack.tracing import utils as tracing_utils + +import langfuse + +HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" +_SUPPORTED_GENERATORS = ["AzureOpenAIGenerator", "OpenAIGenerator"] +_SUPPORTED_CHAT_GENERATORS = ["AzureOpenAIChatGenerator", "OpenAIChatGenerator"] +_ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS + + +class LangfuseSpan(Span): + """ + Internal class representing a bridge between the Haystack span tracing API and Langfuse. + """ + + def __init__(self, span: "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]") -> None: + """ + Initialize a LangfuseSpan instance. + + :param span: The span instance managed by Langfuse. + """ + self._span = span + # locally cache tags + self._data: Dict[str, Any] = {} + + def set_tag(self, key: str, value: Any) -> None: + """ + Set a generic tag for this span. + + :param key: The tag key. + :param value: The tag value. + """ + coerced_value = tracing_utils.coerce_tag_value(value) + self._span.update(metadata={key: coerced_value}) + self._data[key] = value + + def set_content_tag(self, key: str, value: Any) -> None: + """ + Set a content-specific tag for this span. + + :param key: The content tag key. + :param value: The content tag value. + """ + if not tracer.is_content_tracing_enabled: + return + if key.endswith(".input"): + if "messages" in value: + messages = [m.to_openai_format() for m in value["messages"]] + self._span.update(input=messages) + else: + self._span.update(input=value) + elif key.endswith(".output"): + if "replies" in value: + if all(isinstance(r, ChatMessage) for r in value["replies"]): + replies = [m.to_openai_format() for m in value["replies"]] + else: + replies = value["replies"] + self._span.update(output=replies) + else: + self._span.update(output=value) + + self._data[key] = value + + def raw_span(self) -> Any: + """ + Return the underlying span instance. + + :return: The Langfuse span instance. + """ + return self._span + + def get_correlation_data_for_logs(self) -> Dict[str, Any]: + return {} + + +class LangfuseTracer(Tracer): + """ + Internal class representing a bridge between the Haystack tracer and Langfuse. + """ + + def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: bool = False) -> None: + """ + Initialize a LangfuseTracer instance. + + :param tracer: The Langfuse tracer instance. 
+ :param name: The name of the pipeline or component. This name will be used to identify the tracing run on the + Langfuse dashboard. + :param public: Whether the tracing data should be public or private. If set to `True`, the tracing data will + be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private + and only accessible to the Langfuse account owner. + """ + self._tracer = tracer + self._context: list[LangfuseSpan] = [] + self._name = name + self._public = public + self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" + + @contextlib.contextmanager + def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: + """ + Start and manage a new trace span. + :param operation_name: The name of the operation. + :param tags: A dictionary of tags to attach to the span. + :return: A context manager yielding the span. + """ + tags = tags or {} + span_name = tags.get("haystack.component.name", operation_name) + + if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: + span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) + else: + span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) + + self._context.append(span) + span.set_tags(tags) + + yield span + + if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS: + meta = span._data.get("haystack.component.output", {}).get("meta") + if meta: + # Haystack returns one meta dict for each message, but the 'usage' value + # is always the same, let's just pick the first item + m = meta[0] + span._span.update(usage=m.get("usage") or None, model=m.get("model")) + elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get("haystack.component.output", {}).get("replies") + if replies: + meta = replies[0].meta + span._span.update(usage=meta.get("usage") or None, model=meta.get("model")) + + pipeline_input = tags.get("haystack.pipeline.input_data", None) + if pipeline_input: + span._span.update(input=tags["haystack.pipeline.input_data"]) + pipeline_output = tags.get("haystack.pipeline.output_data", None) + if pipeline_output: + span._span.update(output=tags["haystack.pipeline.output_data"]) + + span.raw_span().end() + self._context.pop() + + if len(self._context) == 1: + # The root span has to be a trace, which need to be removed from the context after the pipeline run + self._context.pop() + + if self.enforce_flush: + self.flush() + + def flush(self): + self._tracer.flush() + + def current_span(self) -> Span: + """ + Return the currently active span. + + :return: The currently active span. + """ + if not self._context: + # The root span has to be a trace + self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) + return self._context[-1] + + def get_trace_url(self) -> str: + """ + Return the URL to the tracing data. + :return: The URL to the tracing data. 
+ """ + return self._tracer.get_trace_url() diff --git a/integrations/langfuse/tests/__init__.py b/integrations/langfuse/tests/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/langfuse/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/langfuse/tests/test_langfuse_span.py b/integrations/langfuse/tests/test_langfuse_span.py new file mode 100644 index 000000000..a5a5f2c13 --- /dev/null +++ b/integrations/langfuse/tests/test_langfuse_span.py @@ -0,0 +1,65 @@ +import os + +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +from unittest.mock import Mock +from haystack.dataclasses import ChatMessage +from haystack_integrations.tracing.langfuse.tracer import LangfuseSpan + + +class TestLangfuseSpan: + + # LangfuseSpan can be initialized with a span object + def test_initialized_with_span_object(self): + mock_span = Mock() + span = LangfuseSpan(mock_span) + assert span.raw_span() == mock_span + + # set_tag method can update metadata of the span object + def test_set_tag_updates_metadata(self): + mock_span = Mock() + span = LangfuseSpan(mock_span) + + span.set_tag("key", "value") + mock_span.update.assert_called_once_with(metadata={"key": "value"}) + assert span._data["key"] == "value" + + # set_content_tag method can update input and output of the span object + def test_set_content_tag_updates_input_and_output(self): + mock_span = Mock() + + span = LangfuseSpan(mock_span) + span.set_content_tag("input_key", "input_value") + assert span._data["input_key"] == "input_value" + + mock_span.reset_mock() + span.set_content_tag("output_key", "output_value") + assert span._data["output_key"] == "output_value" + + # set_content_tag method can update input and output of the span object with messages/replies + def test_set_content_tag_updates_input_and_output_with_messages(self): + mock_span = Mock() + + # test message input + span = LangfuseSpan(mock_span) + span.set_content_tag("key.input", {"messages": [ChatMessage.from_user("message")]}) + assert mock_span.update.call_count == 1 + # check we converted ChatMessage to OpenAI format + assert mock_span.update.call_args_list[0][1] == {"input": [{"role": "user", "content": "message"}]} + assert span._data["key.input"] == {"messages": [ChatMessage.from_user("message")]} + + # test replies ChatMessage list + mock_span.reset_mock() + span.set_content_tag("key.output", {"replies": [ChatMessage.from_system("reply")]}) + assert mock_span.update.call_count == 1 + # check we converted ChatMessage to OpenAI format + assert mock_span.update.call_args_list[0][1] == {"output": [{"role": "system", "content": "reply"}]} + assert span._data["key.output"] == {"replies": [ChatMessage.from_system("reply")]} + + # test replies string list + mock_span.reset_mock() + span.set_content_tag("key.output", {"replies": ["reply1", "reply2"]}) + assert mock_span.update.call_count == 1 + # check we handle properly string list replies + assert mock_span.update.call_args_list[0][1] == {"output": ["reply1", "reply2"]} + assert span._data["key.output"] == {"replies": ["reply1", "reply2"]} diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py new file mode 100644 index 000000000..241581a72 --- /dev/null +++ b/integrations/langfuse/tests/test_tracer.py @@ -0,0 +1,114 @@ +import os +from unittest.mock import Mock, MagicMock, patch + +from haystack_integrations.tracing.langfuse.tracer import 
LangfuseTracer + + +class TestLangfuseTracer: + + # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public. + def test_initialization(self): + langfuse_instance = Mock() + tracer = LangfuseTracer(tracer=langfuse_instance, name="Haystack", public=True) + assert tracer._tracer == langfuse_instance + assert tracer._context == [] + assert tracer._name == "Haystack" + assert tracer._public + + # check that the trace method is called on the tracer instance with the provided operation name and tags + # check that the span is added to the context and removed after the context manager exits + def test_create_new_span(self): + mock_raw_span = MagicMock() + mock_raw_span.operation_name = "operation_name" + mock_raw_span.metadata = {"tag1": "value1", "tag2": "value2"} + + with patch("haystack_integrations.tracing.langfuse.tracer.LangfuseSpan") as MockLangfuseSpan: + mock_span_instance = MockLangfuseSpan.return_value + mock_span_instance.raw_span.return_value = mock_raw_span + + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_span_instance + + mock_tracer = MagicMock() + mock_tracer.trace.return_value = mock_context_manager + + tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False) + + with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span: + assert len(tracer._context) == 2, "The trace span should have been added to the the root context span" + assert span.raw_span().operation_name == "operation_name" + assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"} + + assert ( + len(tracer._context) == 0 + ), "The trace span should have been popped, and the root span is closed as well" + + # check that update method is called on the span instance with the provided key value pairs + def test_update_span_with_pipeline_input_output_data(self): + class MockTracer: + + def trace(self, name, **kwargs): + return MockSpan() + + def flush(self): + pass + + class MockSpan: + def __init__(self): + self._data = {} + self._span = self + self.operation_name = "operation_name" + + def raw_span(self): + return self + + def span(self, name=None): + # assert correct operation name passed to the span + assert name == "operation_name" + return self + + def update(self, **kwargs): + self._data.update(kwargs) + + def generation(self, name=None): + return self + + def end(self): + pass + + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: + assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"} + + with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span: + assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"} + + def test_update_span_gets_flushed_by_default(self): + tracer_mock = Mock() + + tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False) + with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: + pass + + tracer_mock.flush.assert_called_once() + + def test_update_span_flush_disable(self, monkeypatch): + monkeypatch.setenv("HAYSTACK_LANGFUSE_ENFORCE_FLUSH", "false") + tracer_mock = Mock() + + from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer + + tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False) + with 
tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: + pass + + tracer_mock.flush.assert_not_called() + + def test_context_is_empty_after_tracing(self): + tracer_mock = Mock() + + tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False) + with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: + pass + + assert tracer._context == [] diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py new file mode 100644 index 000000000..111d89dfd --- /dev/null +++ b/integrations/langfuse/tests/test_tracing.py @@ -0,0 +1,55 @@ +import os + +# don't remove (or move) this env var setting from here, it's needed to turn tracing on +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +from urllib.parse import urlparse + +import pytest +import requests + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from requests.auth import HTTPBasicAuth + +from haystack_integrations.components.connectors.langfuse import LangfuseConnector + + +@pytest.mark.integration +@pytest.mark.skipif( + not os.environ.get("LANGFUSE_SECRET_KEY", None) and not os.environ.get("LANGFUSE_PUBLIC_KEY", None), + reason="Export an env var called LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY containing Langfuse credentials.", +) +def test_tracing_integration(): + + pipe = Pipeline() + pipe.add_component("tracer", LangfuseConnector(name="Chat example", public=True)) # public so anyone can verify run + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) + + pipe.connect("prompt_builder.prompt", "llm.messages") + + messages = [ + ChatMessage.from_system("Always respond in German even if some input data is in other languages."), + ChatMessage.from_user("Tell me about {{location}}"), + ] + + response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + assert "Berlin" in response["llm"]["replies"][0].content + assert response["tracer"]["trace_url"] + url = "https://cloud.langfuse.com/api/public/traces/" + trace_url = response["tracer"]["trace_url"] + parsed_url = urlparse(trace_url) + # trace id is the last part of the path (after the last '/') + uuid = os.path.basename(parsed_url.path) + try: + # GET request with Basic Authentication on the Langfuse API + response = requests.get( + url + uuid, auth=HTTPBasicAuth(os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get("LANGFUSE_SECRET_KEY")) + ) + + assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + except requests.exceptions.RequestException as e: + assert False, f"Failed to retrieve data from Langfuse API: {e}" diff --git a/integrations/llama_cpp/CHANGELOG.md b/integrations/llama_cpp/CHANGELOG.md new file mode 100644 index 000000000..ea4c05e4d --- /dev/null +++ b/integrations/llama_cpp/CHANGELOG.md @@ -0,0 +1,50 @@ +# Changelog + +## [integrations/llama_cpp-v0.4.1] - 2024-08-08 + +### 🐛 Bug Fixes + +- Replace DynamicChatPromptBuilder with ChatPromptBuilder (#940) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Pin `llama-cpp-python>=0.2.87` (#955) + +## [integrations/llama_cpp-v0.4.0] - 2024-05-13 + 
+### 🐛 Bug Fixes + +- Fix commit (#436) + + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- [**breaking**] Rename model_path to model in the Llama.cpp integration (#243) + +### Llama.cpp + +- Generate api docs (#353) + +## [integrations/llama_cpp-v0.2.1] - 2024-01-18 + +## [integrations/llama_cpp-v0.2.0] - 2024-01-17 + +## [integrations/llama_cpp-v0.1.0] - 2024-01-09 + +### 🚀 Features + +- Add Llama.cpp Generator (#179) + + diff --git a/integrations/llama_cpp/pydoc/config.yml b/integrations/llama_cpp/pydoc/config.yml index 98068e672..a2b46c099 100644 --- a/integrations/llama_cpp/pydoc/config.yml +++ b/integrations/llama_cpp/pydoc/config.yml @@ -14,7 +14,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Llama.cpp integration for Haystack category_slug: integrations-api title: Llama.cpp diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 563af391d..2efb15d53 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -15,7 +15,7 @@ authors = [ { name = "Ashwin Mathur", email = "" }, ] classifiers = [ - "License :: OSI Approved :: Apache Software License", + "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3.8", @@ -26,10 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "llama-cpp-python" -] +dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme" @@ -51,49 +48,31 @@ git_describe_command = 'git describe --tags --match="integrations/llama_cpp-v[0- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", + "transformers[sentencepiece]", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references 
= true @@ -139,9 +118,15 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -166,27 +151,16 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] markers = [ "integration: marks tests as slow (deselect with '-m \"not integration\"')", ] -addopts = [ - "--import-mode=importlib", -] +addopts = ["--import-mode=importlib"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*", - "llama_cpp.*" -] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "llama_cpp.*"] ignore_missing_imports = true diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py index cac9235bd..10b20d363 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import LlamaCppChatGenerator from .generator import LlamaCppGenerator -__all__ = ["LlamaCppGenerator"] +__all__ = ["LlamaCppGenerator", "LlamaCppChatGenerator"] diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py new file mode 100644 index 000000000..d43700215 --- /dev/null +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -0,0 +1,139 @@ +import logging +from typing import Any, Dict, List, Optional + +from haystack import component +from haystack.dataclasses import ChatMessage, ChatRole +from llama_cpp import Llama +from llama_cpp.llama_tokenizer import LlamaHFTokenizer + +logger = logging.getLogger(__name__) + + +def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]: + """ + Convert a message to the format expected by Llama.cpp. + :returns: A dictionary with the following keys: + - `role` + - `content` + - `name` (optional) + """ + formatted_msg = {"role": message.role.value, "content": message.content} + if message.name: + formatted_msg["name"] = message.name + + return formatted_msg + + +@component +class LlamaCppChatGenerator: + """ + Provides an interface to generate text using LLM via llama.cpp. + + [llama.cpp](https://github.com/ggerganov/llama.cpp) is a project written in C/C++ for efficient inference of LLMs. + It employs the quantized GGUF format, suitable for running these models on standard machines (even without GPUs). 
+ + Usage example: + ```python + from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator + user_message = [ChatMessage.from_user("Who is the best American actor?")] + generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512) + + print(generator.run(user_message, generation_kwargs={"max_tokens": 128})) + # {"replies": [ChatMessage(content="John Cusack", role=, name=None, meta={...}]} + ``` + """ + + def __init__( + self, + model: str, + n_ctx: Optional[int] = 0, + n_batch: Optional[int] = 512, + model_kwargs: Optional[Dict[str, Any]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf". + If the model path is also specified in the `model_kwargs`, this parameter will be ignored. + :param n_ctx: The number of tokens in the context. When set to 0, the context will be taken from the model. + :param n_batch: Prompt processing maximum batch size. + :param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation. + These keyword arguments provide fine-grained control over the model loading. + In case of duplication, these kwargs override `model`, `n_ctx`, and `n_batch` init parameters. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__). + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion). + """ + + model_kwargs = model_kwargs or {} + generation_kwargs = generation_kwargs or {} + + if "hf_tokenizer_path" in model_kwargs: + tokenizer = LlamaHFTokenizer.from_pretrained(model_kwargs["hf_tokenizer_path"]) + model_kwargs["tokenizer"] = tokenizer + + # check if the model_kwargs contain the essential parameters + # otherwise, populate them with values from init parameters + model_kwargs.setdefault("model_path", model) + model_kwargs.setdefault("n_ctx", n_ctx) + model_kwargs.setdefault("n_batch", n_batch) + + self.model_path = model + self.n_ctx = n_ctx + self.n_batch = n_batch + self.model_kwargs = model_kwargs + self.generation_kwargs = generation_kwargs + self.model = None + + def warm_up(self): + if self.model is None: + self.model = Llama(**self.model_kwargs) + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Run the text generation model on the given list of ChatMessages. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion). + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + if self.model is None: + error_msg = "The model has not been loaded. Please call warm_up() before running." 
+ raise RuntimeError(error_msg) + + if not messages: + return {"replies": []} + + updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages] + + response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs) + replies = [ + ChatMessage( + content=choice["message"]["content"], + role=ChatRole[choice["message"]["role"].upper()], + name=None, + meta={ + "response_id": response["id"], + "model": response["model"], + "created": response["created"], + "index": choice["index"], + "finish_reason": choice["finish_reason"], + "usage": response["usage"], + }, + ) + for choice in response["choices"] + ] + + for reply, choice in zip(replies, response["choices"]): + tool_calls = choice.get("message", {}).get("tool_calls", []) + if tool_calls: + reply.meta["tool_calls"] = tool_calls + reply.name = tool_calls[0]["function"]["name"] if tool_calls else None + return {"replies": replies} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py new file mode 100644 index 000000000..7bd6ef122 --- /dev/null +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -0,0 +1,498 @@ +import json +import os +import urllib.request +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from haystack import Document, Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.retrievers.in_memory import InMemoryBM25Retriever +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import ( + LlamaCppChatGenerator, + _convert_message_to_llamacpp_format, +) + + +@pytest.fixture +def model_path(): + return Path(__file__).parent / "models" + + +def download_file(file_link, filename, capsys): + # Checks if the file already exists before downloading + if not os.path.isfile(filename): + urllib.request.urlretrieve(file_link, filename) # noqa: S310 + with capsys.disabled(): + print("\nModel file downloaded successfully.") + else: + with capsys.disabled(): + print("\nModel file already exists.") + + +def test_convert_message_to_llamacpp_format(): + message = ChatMessage.from_system("You are good assistant") + assert _convert_message_to_llamacpp_format(message) == {"role": "system", "content": "You are good assistant"} + + message = ChatMessage.from_user("I have a question") + assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"} + + message = ChatMessage.from_function("Function call", "function_name") + assert _convert_message_to_llamacpp_format(message) == { + "role": "function", + "content": "Function call", + "name": "function_name", + } + + +class TestLlamaCppChatGenerator: + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) + filename = "openchat-3.5-1210.Q3_K_S.gguf" + + # Download GGUF model from HuggingFace + download_file(gguf_model_path, str(model_path / filename), capsys) + + model_path = str(model_path / filename) + generator = LlamaCppChatGenerator(model=model_path, n_ctx=8192, n_batch=512) + generator.warm_up() + return generator + + @pytest.fixture + def generator_mock(self): + mock_model = MagicMock() + 
generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512) + generator.model = mock_model + return generator, mock_model + + def test_default_init(self): + """ + Test default initialization parameters. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf") + + assert generator.model_path == "test_model.gguf" + assert generator.n_ctx == 0 + assert generator.n_batch == 512 + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 0, "n_batch": 512} + assert generator.generation_kwargs == {} + + def test_custom_init(self): + """ + Test custom initialization parameters. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", + n_ctx=8192, + n_batch=512, + ) + + assert generator.model_path == "test_model.gguf" + assert generator.n_ctx == 8192 + assert generator.n_batch == 512 + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 8192, "n_batch": 512} + assert generator.generation_kwargs == {} + + def test_ignores_model_path_if_specified_in_model_kwargs(self): + """ + Test that model_path is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", + n_ctx=8192, + n_batch=512, + model_kwargs={"model_path": "other_model.gguf"}, + ) + assert generator.model_kwargs["model_path"] == "other_model.gguf" + + def test_ignores_n_ctx_if_specified_in_model_kwargs(self): + """ + Test that n_ctx is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 8192}) + assert generator.model_kwargs["n_ctx"] == 8192 + + def test_ignores_n_batch_if_specified_in_model_kwargs(self): + """ + Test that n_batch is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", n_ctx=8192, n_batch=512, model_kwargs={"n_batch": 1024} + ) + assert generator.model_kwargs["n_batch"] == 1024 + + def test_raises_error_without_warm_up(self): + """ + Test that the generator raises an error if warm_up() is not called before running. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512) + with pytest.raises(RuntimeError): + generator.run("What is the capital of China?") + + def test_run_with_empty_message(self, generator_mock): + """ + Test that an empty message returns an empty list of replies. + """ + generator, _ = generator_mock + result = generator.run([]) + assert isinstance(result["replies"], list) + assert len(result["replies"]) == 0 + + def test_run_with_valid_message(self, generator_mock): + """ + Test that a valid message returns a list of replies. 
+ """ + generator, mock_model = generator_mock + mock_output = { + "id": "unique-id-123", + "model": "Test Model Path", + "created": 1715226164, + "choices": [ + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, + } + mock_model.create_chat_completion.return_value = mock_output + result = generator.run(messages=[ChatMessage.from_system("Test")]) + assert isinstance(result["replies"], list) + assert len(result["replies"]) == 1 + assert isinstance(result["replies"][0], ChatMessage) + assert result["replies"][0].content == "Generated text" + assert result["replies"][0].role == ChatRole.ASSISTANT + + def test_run_with_generation_kwargs(self, generator_mock): + """ + Test that a valid message and generation kwargs returns a list of replies. + """ + generator, mock_model = generator_mock + mock_output = { + "id": "unique-id-123", + "model": "Test Model Path", + "created": 1715226164, + "choices": [ + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "length"} + ], + "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, + } + mock_model.create_chat_completion.return_value = mock_output + generation_kwargs = {"max_tokens": 128} + result = generator.run([ChatMessage.from_system("Write a 200 word paragraph.")], generation_kwargs) + assert result["replies"][0].content == "Generated text" + assert result["replies"][0].meta["finish_reason"] == "length" + + @pytest.mark.integration + def test_run(self, generator): + """ + Test that a valid message returns a list of replies. + """ + questions_and_answers = [ + ("What's the capital of France?", "Paris"), + ("What is the capital of Canada?", "Ottawa"), + ("What is the capital of Ghana?", "Accra"), + ] + + for question, answer in questions_and_answers: + chat_message = ChatMessage.from_system( + f"GPT4 Correct User: Answer in a single word. {question} <|end_of_turn|>\n GPT4 Correct Assistant:" + ) + result = generator.run([chat_message]) + + assert "replies" in result + assert isinstance(result["replies"], list) + assert len(result["replies"]) > 0 + assert any(answer.lower() in reply.content.lower() for reply in result["replies"]) + + @pytest.mark.integration + def test_run_rag_pipeline(self, generator): + """ + Test that a valid message returns a list of replies. + """ + document_store = InMemoryDocumentStore() + documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a high + level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, + and San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), + ] + document_store.write_documents(documents=documents) + + pipeline = Pipeline() + pipeline.add_component( + instance=InMemoryBM25Retriever(document_store=document_store, top_k=1), + name="retriever", + ) + pipeline.add_component(instance=ChatPromptBuilder(variables=["query", "documents"]), name="prompt_builder") + pipeline.add_component(instance=generator, name="llm") + pipeline.connect("retriever.documents", "prompt_builder.documents") + pipeline.connect("prompt_builder.prompt", "llm.messages") + + question = "How many languages are there?" 
+ location = "Puerto Rico" + system_message = ChatMessage.from_system( + "You are a helpful assistant giving out valuable information to tourists." + ) + messages = [ + system_message, + ChatMessage.from_user( + """ + Given these documents and given that I am currently in {{ location }}, answer the question.\nDocuments: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + \nQuestion: {{query}} + \nAnswer: + """ + ), + ] + question = "Can I see bioluminescent waves at my current location?" + result = pipeline.run( + data={ + "retriever": {"query": question}, + "prompt_builder": { + "template_variables": {"location": location}, + "template": messages, + "query": question, + }, + } + ) + + replies = result["llm"]["replies"] + assert len(replies) > 0 + assert any("bioluminescent waves" in reply.content for reply in replies) + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + @pytest.mark.integration + def test_json_constraining(self, generator): + """ + Test that the generator can output valid JSON. + """ + messages = [ChatMessage.from_system("Output valid json only. List 2 people with their name and age.")] + json_schema = { + "type": "object", + "properties": { + "people": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + }, + }, + }, + "required": ["people"], + } + + result = generator.run( + messages=messages, + generation_kwargs={ + "response_format": {"type": "json_object", "schema": json_schema}, + }, + ) + + assert "replies" in result + assert isinstance(result["replies"], list) + assert len(result["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in result["replies"]) + for reply in result["replies"]: + assert json.loads(reply.content) + assert isinstance(json.loads(reply.content), dict) + assert "people" in json.loads(reply.content) + assert isinstance(json.loads(reply.content)["people"], list) + assert all(isinstance(person, dict) for person in json.loads(reply.content)["people"]) + assert all("name" in person for person in json.loads(reply.content)["people"]) + assert all("age" in person for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["name"], str) for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["age"], int) for person in json.loads(reply.content)["people"]) + + +class TestLlamaCppChatGeneratorFunctionary: + def get_current_temperature(self, location): + """Get the current temperature in a given location""" + if "tokyo" in location.lower(): + return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"}) + elif "san francisco" in location.lower(): + return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}) + elif "paris" in location.lower(): + return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"}) + else: + return json.dumps({"location": location, "temperature": "unknown"}) + + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/meetkai/functionary-small-v2.4-GGUF/resolve/main/functionary-small-v2.4.Q4_0.gguf" + ) + filename = "functionary-small-v2.4.Q4_0.gguf" + download_file(gguf_model_path, str(model_path / filename), capsys) + model_path = str(model_path / filename) + hf_tokenizer_path = "meetkai/functionary-small-v2.4-GGUF" + generator = LlamaCppChatGenerator( + model=model_path, + n_ctx=8192, + n_batch=512, + model_kwargs={ + 
"chat_format": "functionary-v2", + "hf_tokenizer_path": hf_tokenizer_path, + }, + ) + generator.warm_up() + return generator + + @pytest.mark.integration + def test_function_call(self, generator): + tools = [ + { + "type": "function", + "function": { + "name": "get_user_info", + "parameters": { + "type": "object", + "properties": { + "username": {"type": "string", "description": "The username to retrieve information for."} + }, + "required": ["username"], + }, + "description": "Retrieves detailed information about a user.", + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "get_user_info"}} + + messages = [ + ChatMessage.from_user("Get information for user john_doe"), + ] + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) + + assert "tool_calls" in response["replies"][0].meta + tool_calls = response["replies"][0].meta["tool_calls"] + assert len(tool_calls) > 0 + assert tool_calls[0]["function"]["name"] == "get_user_info" + assert "username" in json.loads(tool_calls[0]["function"]["arguments"]) + assert response["replies"][0].role == ChatRole.ASSISTANT + + def test_function_call_and_execute(self, generator): + messages = [ChatMessage.from_user("What's the weather like in San Francisco?")] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Get the current temperature in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + response = generator.run(messages=messages, generation_kwargs={"tools": tools}) + + available_functions = { + "get_current_temperature": self.get_current_temperature, + } + + assert "replies" in response + assert len(response["replies"]) > 0 + + first_reply = response["replies"][0] + assert "tool_calls" in first_reply.meta + tool_calls = first_reply.meta["tool_calls"] + + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + assert function_name in available_functions + function_response = available_functions[function_name](**function_args) + function_message = ChatMessage.from_function(function_response, function_name) + messages.append(function_message) + + second_response = generator.run(messages=messages) + assert "replies" in second_response + assert len(second_response["replies"]) > 0 + assert any("San Francisco" in reply.content for reply in second_response["replies"]) + assert any("72" in reply.content for reply in second_response["replies"]) + + +class TestLlamaCppChatGeneratorChatML: + + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) + filename = "openchat-3.5-1210.Q3_K_S.gguf" + download_file(gguf_model_path, str(model_path / filename), capsys) + model_path = str(model_path / filename) + generator = LlamaCppChatGenerator( + model=model_path, + n_ctx=8192, + n_batch=512, + model_kwargs={ + "chat_format": "chatml-function-calling", + }, + ) + generator.warm_up() + return generator + + @pytest.mark.integration + def test_function_call_chatml(self, generator): + messages = [ + ChatMessage.from_system( + """A chat between a curious user and an artificial 
intelligence assistant. The assistant gives helpful, + detailed, and polite answers to the user's questions. The assistant calls functions with appropriate + input when necessary""" + ), + ChatMessage.from_user("Extract Jason is 25 years old"), + ] + + tools = [ + { + "type": "function", + "function": { + "name": "UserDetail", + "parameters": { + "type": "object", + "title": "UserDetail", + "properties": { + "name": {"title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + }, + "required": ["name", "age"], + }, + }, + } + ] + + tool_choice = {"type": "function", "function": {"name": "UserDetail"}} + + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) + for reply in response["replies"]: + assert "tool_calls" in reply.meta + tool_calls = reply.meta["tool_calls"] + assert len(tool_calls) > 0 + assert tool_calls[0]["function"]["name"] == "UserDetail" + assert "name" in json.loads(tool_calls[0]["function"]["arguments"]) + assert "age" in json.loads(tool_calls[0]["function"]["arguments"]) + assert "Jason" in json.loads(tool_calls[0]["function"]["arguments"])["name"] + assert 25 == json.loads(tool_calls[0]["function"]["arguments"])["age"] diff --git a/integrations/mistral/examples/streaming_chat_with_rag.py b/integrations/mistral/examples/streaming_chat_with_rag.py index 2e3eeee5a..6c7f015d8 100644 --- a/integrations/mistral/examples/streaming_chat_with_rag.py +++ b/integrations/mistral/examples/streaming_chat_with_rag.py @@ -2,7 +2,7 @@ # This example streams chat replies to the console. from haystack import Pipeline -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.components.builders import ChatPromptBuilder from haystack.components.converters import HTMLToDocument from haystack.components.fetchers import LinkContentFetcher from haystack.components.generators.utils import print_streaming_chunk @@ -39,7 +39,7 @@ text_embedder = MistralTextEmbedder() retriever = InMemoryEmbeddingRetriever(document_store=document_store) -prompt_builder = DynamicChatPromptBuilder(runtime_variables=["documents"]) +prompt_builder = ChatPromptBuilder(variables=["documents"]) llm = MistralChatGenerator(streaming_callback=print_streaming_chunk) messages = [ChatMessage.from_user("Here are some the documents: {{documents}} \\n Answer: {{query}}")] @@ -60,7 +60,7 @@ result = rag_pipeline.run( { "text_embedder": {"text": question}, - "prompt_builder": {"template_variables": {"query": question}, "prompt_source": messages}, + "prompt_builder": {"template_variables": {"query": question}, "template": messages}, "llm": {"generation_kwargs": {"max_tokens": 165}}, } ) diff --git a/integrations/mistral/pydoc/config.yml b/integrations/mistral/pydoc/config.yml index 86ad5f1d0..c26843a54 100644 --- a/integrations/mistral/pydoc/config.yml +++ b/integrations/mistral/pydoc/config.yml @@ -16,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Mistral integration for Haystack category_slug: integrations-api title: Mistral diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 460dbc4cc..bcb4f5999 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -41,12 +41,14 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/mistral-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -57,7 +59,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -135,12 +137,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/mongodb_atlas/CHANGELOG.md b/integrations/mongodb_atlas/CHANGELOG.md new file mode 100644 index 000000000..851858355 --- /dev/null +++ b/integrations/mongodb_atlas/CHANGELOG.md @@ -0,0 +1,64 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Defer the database connection to when it's needed (#770) +- Add filter_policy to mongodb_atlas integration (#823) + +### 🐛 Bug Fixes + +- Pass empty dict to filter instead of None (#775) +- `Mongo` - Fallback to default filter policy when deserializing retrievers without the init parameter (#899) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/mongodb_atlas-v0.2.1] - 2024-04-09 + +### 🐛 Bug Fixes + +- Fix haystack-ai pin (#649) + + + +### 📚 Documentation + +- Disable-class-def (#556) + +## [integrations/mongodb_atlas-v0.2.0] - 2024-03-09 + +### 📚 Documentation + +- Mongo atlas (#534) +- Final API docs touches (#538) + +### Mongodb + +- Improve example (#546) + +## [integrations/mongodb_atlas-v0.1.0] - 2024-02-23 + +### 🚀 Features + +- MongoDBAtlas Document Store (#413) +- `MongoDBAtlasEmbeddingRetriever` (#427) + +### 🐛 Bug Fixes + +- Remove filters from `MongoDBAtlasDocumentStore.count()` method (#430) +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Fix pyproject for mongodbatlas (#478) + + + +### 📚 Documentation + +- Update category slug (#442) + + diff --git a/integrations/mongodb_atlas/pydoc/config.yml b/integrations/mongodb_atlas/pydoc/config.yml index 85694c57f..a38b0a449 100644 --- a/integrations/mongodb_atlas/pydoc/config.yml +++ b/integrations/mongodb_atlas/pydoc/config.yml @@ -16,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: MongoDB Atlas integration for Haystack category_slug: integrations-api title: MongoDB Atlas diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index 
df6a1dec6..170f6e94d 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -25,10 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "pymongo[srv]", -] +dependencies = ["haystack-ai", "pymongo[srv]"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -50,49 +45,30 @@ git_describe_command = 'git describe --tags --match="integrations/mongodb_atlas- dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "ipython", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -135,9 +111,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -164,19 +146,10 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pymongo.*", - "pytest.*" -] +module = ["haystack.*", "haystack_integrations.*", "pymongo.*", "pytest.*"] ignore_missing_imports = true diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index ffad97789..91a42e135 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @@ -43,6 +45,7 @@ def __init__( document_store: MongoDBAtlasDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Create the MongoDBAtlasDocumentStore component. @@ -52,6 +55,7 @@ def __init__( included in the configuration of the `vector_search_index`. The configuration must be done manually in the Web UI of MongoDB Atlas. :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. """ @@ -62,6 +66,9 @@ def __init__( self.document_store = document_store self.filters = filters or {} self.top_k = top_k + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) def to_dict(self) -> Dict[str, Any]: """ @@ -74,6 +81,7 @@ def to_dict(self) -> Dict[str, Any]: self, filters=self.filters, top_k=self.top_k, + filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -90,6 +98,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. 
+ if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -103,12 +115,14 @@ def run( Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided embedding similarity. :param query_embedding: Embedding of the query. - :param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. :returns: A dictionary with the following keys: - `documents`: List of Documents most similar to the given `query_embedding` """ - filters = filters or self.filters + filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k docs = self.document_store._embedding_retrieval( diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 4cb5b8659..93eb87005 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -12,6 +12,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.collection import Collection from pymongo.driver_info import DriverInfo from pymongo.errors import BulkWriteError @@ -81,22 +82,33 @@ def __init__( msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' raise ValueError(msg) - resolved_connection_string = mongo_connection_string.resolve_value() self.mongo_connection_string = mongo_connection_string self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index + self._connection: Optional[MongoClient] = None + self._collection: Optional[Collection] = None - self.connection: MongoClient = MongoClient( - resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") - ) - database = self.connection[self.database_name] + @property + def connection(self) -> MongoClient: + if self._connection is None: + self._connection = MongoClient( + self.mongo_connection_string.resolve_value(), driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) - if collection_name not in database.list_collection_names(): - msg = f"Collection '{collection_name}' does not exist in database '{database_name}'." - raise ValueError(msg) - self.collection = database[self.collection_name] + return self._connection + + @property + def collection(self) -> Collection: + if self._collection is None: + database = self.connection[self.database_name] + + if self.collection_name not in database.list_collection_names(): + msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'." 
+ raise ValueError(msg) + self._collection = database[self.collection_name] + return self._collection def to_dict(self) -> Dict[str, Any]: """ @@ -233,7 +245,7 @@ def _embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - filters = _normalize_filters(filters) if filters else None + filters = _normalize_filters(filters) if filters else {} pipeline = [ { diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 89810ec8b..453d9d16c 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import os +from unittest.mock import patch from uuid import uuid4 import pytest @@ -16,13 +17,23 @@ from pymongo.driver_info import DriverInfo +@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") +def test_init_is_lazy(_mock_client): + MongoDBAtlasDocumentStore( + mongo_connection_string=Secret.from_token("test"), + database_name="database_name", + collection_name="collection_name", + vector_search_index="cosine_index", + ) + _mock_client.assert_not_called() + + @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): - @pytest.fixture def document_store(self): database_name = "haystack_integration_test" diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 4ef5222ce..56eec928f 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -5,6 +5,7 @@ import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @@ -31,6 +32,13 @@ def test_init_default(self): assert retriever.document_store == mock_store assert retriever.filters == {} assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + + retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + MongoDBAtlasEmbeddingRetriever(document_store=mock_store, filter_policy="wrong_policy") def test_init(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) @@ -42,6 +50,20 @@ def test_init(self): assert retriever.document_store == mock_store assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_init_filter_policy_merge(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasEmbeddingRetriever( + document_store=mock_store, + filters={"field": "value"}, + top_k=5, + filter_policy=FilterPolicy.MERGE, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.MERGE def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") @@ -72,6 +94,7 @@ def 
test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a }, "filters": {"field": "value"}, "top_k": 5, + "filter_policy": "replace", }, } @@ -96,6 +119,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client }, "filters": {"field": "value"}, "top_k": 5, + "filter_policy": "replace", }, } @@ -109,6 +133,43 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client assert document_store.vector_search_index == "cosine_index" assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_embeddings_collection", + "vector_search_index": "cosine_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = MongoDBAtlasEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_embeddings_collection" + assert document_store.vector_search_index == "cosine_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE def test_run(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) @@ -121,3 +182,19 @@ def test_run(self): mock_store._embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10) assert res == {"documents": [doc]} + + def test_run_merge_policy_filter(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._embedding_retrieval.return_value = [doc] + + retriever = MongoDBAtlasEmbeddingRetriever( + document_store=mock_store, filters={"foo": "boo"}, filter_policy=FilterPolicy.MERGE + ) + res = retriever.run(query_embedding=[0.3, 0.5], filters={"field": "value"}) + + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={"field": "value", "foo": "boo"}, top_k=10 + ) + + assert res == {"documents": [doc]} diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md new file mode 100644 index 000000000..a00c913d2 --- /dev/null +++ b/integrations/nvidia/CHANGELOG.md @@ -0,0 +1,42 @@ +# Changelog + +## [unreleased] + +### 🚜 Refactor + +- Remove deprecated Nvidia Cloud Functions backend and related code. 
(#803) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/nvidia-v0.0.3] - 2024-05-22 + +### 📚 Documentation + +- Update docstrings of Nvidia integrations (#599) + +### ⚙️ Miscellaneous Tasks + +- Add generate docs to Nvidia workflow (#603) + +## [integrations/nvidia-v0.0.2] - 2024-03-18 + +### 📚 Documentation + +- Disable-class-def (#556) + +## [integrations/nvidia-v0.0.1] - 2024-03-07 + +### 🚀 Features + +- Add `NvidiaTextEmbedder`, `NvidiaDocumentEmbedder` and co. (#537) + +### 🐛 Bug Fixes + +- `nvidia-haystack`- Handle non-strict env var secrets correctly (#543) + +## [integrations/nvidia-v0.0.0] - 2024-03-01 + + diff --git a/integrations/nvidia/pydoc/config.yml b/integrations/nvidia/pydoc/config.yml index 7e9811d25..80bb212c5 100644 --- a/integrations/nvidia/pydoc/config.yml +++ b/integrations/nvidia/pydoc/config.yml @@ -17,12 +17,12 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Nvidia integration for Haystack category_slug: integrations-api title: Nvidia slug: integrations-nvidia - order: 50 + order: 165 markdown: descriptive_class_title: false classdef_code_block: false diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index 753f4f938..af1d806ef 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -42,12 +42,14 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/nvidia-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", "requests_mock"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -58,7 +60,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -117,7 +119,6 @@ unfixable = [ # Don't touch unused imports "F401", ] -extend-exclude = ["tests", "example"] [tool.ruff.isort] known-first-party = ["src"] @@ -148,6 +149,8 @@ module = [ "haystack_integrations.*", "pytest.*", "numpy.*", + "requests_mock.*", + "pydantic.*" ] ignore_missing_imports = true diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index 588aca2e6..bc2d9372c 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -1,7 +1,5 @@ from .document_embedder import NvidiaDocumentEmbedder from .text_embedder import NvidiaTextEmbedder +from .truncate import 
EmbeddingTruncateMode -__all__ = [ - "NvidiaDocumentEmbedder", - "NvidiaTextEmbedder", -] +__all__ = ["NvidiaDocumentEmbedder", "NvidiaTextEmbedder", "EmbeddingTruncateMode"] diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py deleted file mode 100644 index 27e0dbeac..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nim_backend.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import requests - -from .backend import EmbedderBackend - -REQUEST_TIMEOUT = 60 - - -class NimBackend(EmbedderBackend): - def __init__( - self, - model: str, - api_url: str, - model_kwargs: Optional[Dict[str, Any]] = None, - ): - headers = { - "Content-Type": "application/json", - "accept": "application/json", - } - self.session = requests.Session() - self.session.headers.update(headers) - - self.model = model - self.api_url = api_url - self.model_kwargs = model_kwargs or {} - - def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: - url = f"{self.api_url}/embeddings" - - res = self.session.post( - url, - json={ - "model": self.model, - "input": texts, - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() - - data = res.json() - # Sort the embeddings by index, we don't know whether they're out of order or not - embeddings = [e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"])] - - return embeddings, {"usage": data["usage"]} diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nvcf_backend.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nvcf_backend.py deleted file mode 100644 index 7d4b07dca..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_nvcf_backend.py +++ /dev/null @@ -1,109 +0,0 @@ -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Literal, Optional, Tuple, Union - -from haystack.utils.auth import Secret -from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient - -from .backend import EmbedderBackend - -MAX_INPUT_STRING_LENGTH = 2048 -MAX_INPUTS = 50 - - -class NvcfBackend(EmbedderBackend): - def __init__( - self, - model: str, - api_key: Secret, - model_kwargs: Optional[Dict[str, Any]] = None, - ): - if not model.startswith("playground_"): - model = f"playground_{model}" - - super().__init__(model=model, model_kwargs=model_kwargs) - - self.api_key = api_key - self.client = NvidiaCloudFunctionsClient( - api_key=api_key, - headers={ - "Content-Type": "application/json", - "Accept": "application/json", - }, - ) - self.nvcf_id = self.client.get_model_nvcf_id(self.model_name) - - def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: - request = EmbeddingsRequest(input=texts, **self.model_kwargs).to_dict() - json_response = self.client.query_function(self.nvcf_id, request) - response = EmbeddingsResponse.from_dict(json_response) - - # Sort resulting embeddings by index - assert all(isinstance(r.embedding, list) for r in response.data) - sorted_embeddings: List[List[float]] = [r.embedding for r in sorted(response.data, key=lambda e: e.index)] # type: ignore - metadata = {"usage": response.usage.to_dict()} - return sorted_embeddings, metadata - - -@dataclass -class EmbeddingsRequest: - input: Union[str, List[str]] - model: Literal["query", 
"passage"] - encoding_format: Literal["float", "base64"] = "float" - - def __post_init__(self): - if isinstance(self.input, list): - if len(self.input) > MAX_INPUTS: - msg = f"The number of inputs should not exceed {MAX_INPUTS}" - raise ValueError(msg) - else: - self.input = [self.input] - - if len(self.input) == 0: - msg = "The number of inputs should not be 0" - raise ValueError(msg) - - if any(len(x) > MAX_INPUT_STRING_LENGTH for x in self.input): - msg = f"The length of each input should not exceed {MAX_INPUT_STRING_LENGTH} characters" - raise ValueError(msg) - - if self.encoding_format not in ["float", "base64"]: - msg = "encoding_format should be either 'float' or 'base64'" - raise ValueError(msg) - - if self.model not in ["query", "passage"]: - msg = "model should be either 'query' or 'passage'" - raise ValueError(msg) - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -@dataclass -class Usage: - prompt_tokens: int - total_tokens: int - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -@dataclass -class Embeddings: - index: int - embedding: Union[List[float], str] - - -@dataclass -class EmbeddingsResponse: - data: List[Embeddings] - usage: Usage - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingsResponse": - try: - embeddings = [Embeddings(**x) for x in data["data"]] - usage = Usage(**data["usage"]) - return cls(data=embeddings, usage=usage) - except (KeyError, TypeError) as e: - msg = f"Failed to parse EmbeddingsResponse from data: {data}" - raise ValueError(msg) from e diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/backend.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/backend.py deleted file mode 100644 index 09e9b7c80..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/backend.py +++ /dev/null @@ -1,29 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple - - -class EmbedderBackend(ABC): - def __init__(self, model: str, model_kwargs: Optional[Dict[str, Any]] = None): - """ - Initialize the backend. - - :param model: - The name of the model to use. - :param model_kwargs: - Additional keyword arguments to pass to the model. - """ - self.model_name = model - self.model_kwargs = model_kwargs or {} - - @abstractmethod - def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: - """ - Invoke the backend and embed the given texts. - - :param texts: - Texts to embed. - :return: - Vector representation of the texts and - metadata returned by the service. 
- """ - pass diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index da181bd22..f5d1747b8 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -1,20 +1,21 @@ -from typing import Any, Dict, List, Optional, Tuple +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation from tqdm import tqdm -from ._nim_backend import NimBackend -from ._nvcf_backend import NvcfBackend -from .backend import EmbedderBackend +from .truncate import EmbeddingTruncateMode + +_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @component class NvidiaDocumentEmbedder: """ A component for embedding documents using embedding models provided by - [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) - and NVIDIA Inference Microservices. + [NVIDIA NIMs](https://ai.nvidia.com). Usage example: ```python @@ -22,7 +23,7 @@ class NvidiaDocumentEmbedder: doc = Document(content="I love pizza!") - text_embedder = NvidiaDocumentEmbedder(model="nvolveqa_40k") + text_embedder = NvidiaDocumentEmbedder(model="NV-Embed-QA", api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia") text_embedder.warm_up() result = document_embedder.run([doc]) @@ -32,25 +33,29 @@ class NvidiaDocumentEmbedder: def __init__( self, - model: str, + model: Optional[str] = None, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), - api_url: Optional[str] = None, + api_url: str = _DEFAULT_API_URL, prefix: str = "", suffix: str = "", batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, ): """ Create a NvidiaTextEmbedder component. :param model: Embedding model to use. + If no specific model along with locally hosted API URL is provided, + the system defaults to the available model found using /models API. :param api_key: - API key for the NVIDIA AI Foundation Endpoints. + API key for the NVIDIA NIM. :param api_url: - Custom API URL for the NVIDIA Inference Microservices. + Custom API URL for the NVIDIA NIM. + Format for API URL is http://host:port :param prefix: A string to add to the beginning of each text. :param suffix: @@ -64,11 +69,14 @@ def __init__( List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param truncate: + Specifies how inputs longer that the maximum token length should be truncated. + If None the behavior is model-dependent, see the official documentation for more information. 
""" self.api_key = api_key self.model = model - self.api_url = api_url + self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) self.prefix = prefix self.suffix = suffix self.batch_size = batch_size @@ -76,9 +84,35 @@ def __init__( self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator - self.backend: Optional[EmbedderBackend] = None + if isinstance(truncate, str): + truncate = EmbeddingTruncateMode.from_str(truncate) + self.truncate = truncate + + self.backend: Optional[Any] = None self._initialized = False + if is_hosted(api_url) and not self.model: # manually set default model + self.model = "NV-Embed-QA" + + def default_model(self): + """Set default model in local NIM mode.""" + valid_models = [ + model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id + ] + name = next(iter(valid_models), None) + if name: + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + stacklevel=2, + ) + self.model = self.backend.model = name + else: + error_message = "No locally hosted model was found." + raise ValueError(error_message) + def warm_up(self): """ Initializes the component. @@ -86,17 +120,21 @@ def warm_up(self): if self._initialized: return - if self.api_url is None: - if self.api_key is None: - msg = "API key is required for NVIDIA AI Foundation Endpoints." - raise ValueError(msg) - - self.backend = NvcfBackend(self.model, api_key=self.api_key, model_kwargs={"model": "passage"}) - else: - self.backend = NimBackend(self.model, api_url=self.api_url, model_kwargs={"input_type": "passage"}) + model_kwargs = {"input_type": "passage"} + if self.truncate is not None: + model_kwargs["truncate"] = str(self.truncate) + self.backend = NimBackend( + self.model, + api_url=self.api_url, + api_key=self.api_key, + model_kwargs=model_kwargs, + ) self._initialized = True + if not self.model: + self.default_model() + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. 
@@ -115,6 +153,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + truncate=str(self.truncate) if self.truncate is not None else None, ) @classmethod diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 6af5ba25f..1c4a7c5c9 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -1,19 +1,20 @@ -from typing import Any, Dict, List, Optional +import warnings +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from ._nim_backend import NimBackend -from ._nvcf_backend import NvcfBackend -from .backend import EmbedderBackend +from .truncate import EmbeddingTruncateMode + +_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @component class NvidiaTextEmbedder: """ A component for embedding strings using embedding models provided by - [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) - and NVIDIA Inference Microservices. + [NVIDIA NIMs](https://ai.nvidia.com). For models that differentiate between query and document inputs, this component embeds the input string as a query. @@ -24,7 +25,7 @@ class NvidiaTextEmbedder: text_to_embed = "I love pizza!" - text_embedder = NvidiaTextEmbedder(model="nvolveqa_40k") + text_embedder = NvidiaTextEmbedder(model="NV-Embed-QA", api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia") text_embedder.warm_up() print(text_embedder.run(text_to_embed)) @@ -33,36 +34,69 @@ class NvidiaTextEmbedder: def __init__( self, - model: str, + model: Optional[str] = None, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), - api_url: Optional[str] = None, + api_url: str = _DEFAULT_API_URL, prefix: str = "", suffix: str = "", + truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, ): """ Create a NvidiaTextEmbedder component. :param model: Embedding model to use. + If no specific model along with locally hosted API URL is provided, + the system defaults to the available model found using /models API. :param api_key: - API key for the NVIDIA AI Foundation Endpoints. + API key for the NVIDIA NIM. :param api_url: - Custom API URL for the NVIDIA Inference Microservices. + Custom API URL for the NVIDIA NIM. + Format for API URL is http://host:port :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. + :param truncate: + Specifies how inputs longer that the maximum token length should be truncated. + If None the behavior is model-dependent, see the official documentation for more information. 
""" self.api_key = api_key self.model = model - self.api_url = api_url + self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) self.prefix = prefix self.suffix = suffix - self.backend: Optional[EmbedderBackend] = None + if isinstance(truncate, str): + truncate = EmbeddingTruncateMode.from_str(truncate) + self.truncate = truncate + + self.backend: Optional[Any] = None self._initialized = False + if is_hosted(api_url) and not self.model: # manually set default model + self.model = "NV-Embed-QA" + + def default_model(self): + """Set default model in local NIM mode.""" + valid_models = [ + model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id + ] + name = next(iter(valid_models), None) + if name: + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + stacklevel=2, + ) + self.model = self.backend.model = name + else: + error_message = "No locally hosted model was found." + raise ValueError(error_message) + def warm_up(self): """ Initializes the component. @@ -70,17 +104,21 @@ def warm_up(self): if self._initialized: return - if self.api_url is None: - if self.api_key is None: - msg = "API key is required for NVIDIA AI Foundation Endpoints." - raise ValueError(msg) - - self.backend = NvcfBackend(self.model, api_key=self.api_key, model_kwargs={"model": "query"}) - else: - self.backend = NimBackend(self.model, api_url=self.api_url, model_kwargs={"input_type": "query"}) + model_kwargs = {"input_type": "query"} + if self.truncate is not None: + model_kwargs["truncate"] = str(self.truncate) + self.backend = NimBackend( + self.model, + api_url=self.api_url, + api_key=self.api_key, + model_kwargs=model_kwargs, + ) self._initialized = True + if not self.model: + self.default_model() + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -95,6 +133,7 @@ def to_dict(self) -> Dict[str, Any]: api_url=self.api_url, prefix=self.prefix, suffix=self.suffix, + truncate=str(self.truncate) if self.truncate is not None else None, ) @classmethod diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py new file mode 100644 index 000000000..2c32eabb1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class EmbeddingTruncateMode(Enum): + """ + Specifies how inputs to the NVIDIA embedding components are truncated. + If START, the input will be truncated from the start. + If END, the input will be truncated from the end. + """ + + START = "START" + END = "END" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "EmbeddingTruncateMode": + """ + Create an truncate mode from a string. + + :param string: + String to convert. + :returns: + Truncate mode. + """ + enum_map = {e.value: e for e in EmbeddingTruncateMode} + opt_mode = enum_map.get(string) + if opt_mode is None: + msg = f"Unknown truncate mode '{string}'. 
Supported modes are: {list(enum_map.keys())}" + raise ValueError(msg) + return opt_mode diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py deleted file mode 100644 index 499a60b78..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nim_backend.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import requests - -from .backend import GeneratorBackend - -REQUEST_TIMEOUT = 60 - - -class NimBackend(GeneratorBackend): - def __init__( - self, - model: str, - api_url: str, - model_kwargs: Optional[Dict[str, Any]] = None, - ): - headers = { - "Content-Type": "application/json", - "accept": "application/json", - } - self.session = requests.Session() - self.session.headers.update(headers) - - self.model = model - self.api_url = api_url - self.model_kwargs = model_kwargs or {} - - def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: - # We're using the chat completion endpoint as the local containers don't support - # the /completions endpoint. So both the non-chat and chat generator will use this. - url = f"{self.api_url}/chat/completions" - - res = self.session.post( - url, - json={ - "model": self.model, - "messages": [ - { - "role": "user", - "content": prompt, - }, - ], - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() - - completions = res.json() - choices = completions["choices"] - # Sort the choices by index, we don't know whether they're out of order or not - choices.sort(key=lambda c: c["index"]) - replies = [] - meta = [] - for choice in choices: - message = choice["message"] - replies.append(message["content"]) - choice_meta = { - "role": message["role"], - "finish_reason": choice["finish_reason"], - "usage": { - "prompt_tokens": completions["usage"]["prompt_tokens"], - "completion_tokens": completions["usage"]["completion_tokens"], - "total_tokens": completions["usage"]["total_tokens"], - }, - } - meta.append(choice_meta) - - return replies, meta diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nvcf_backend.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nvcf_backend.py deleted file mode 100644 index c0686c132..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_nvcf_backend.py +++ /dev/null @@ -1,117 +0,0 @@ -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple - -from haystack.utils.auth import Secret -from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient - -from .backend import GeneratorBackend - - -class NvcfBackend(GeneratorBackend): - def __init__( - self, - model: str, - api_key: Secret, - model_kwargs: Optional[Dict[str, Any]] = None, - ): - if not model.startswith("playground_"): - model = f"playground_{model}" - - super().__init__(model=model, model_kwargs=model_kwargs) - - self.api_key = api_key - self.client = NvidiaCloudFunctionsClient( - api_key=api_key, - headers={ - "Content-Type": "application/json", - "Accept": "application/json", - }, - ) - self.nvcf_id = self.client.get_model_nvcf_id(self.model_name) - - def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: - messages = [Message(role="user", content=prompt)] - request = GenerationRequest(messages=messages, **self.model_kwargs).to_dict() - 
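A minimal usage sketch of the new truncate option on the text embedder, tying together the `EmbeddingTruncateMode` enum introduced above and the hosted defaults used in this patch; apart from those names and URLs, the values are purely illustrative:

```python
from haystack_integrations.components.embedders.nvidia import (
    EmbeddingTruncateMode,
    NvidiaTextEmbedder,
)

# The hosted retrieval endpoint requires NVIDIA_API_KEY in the environment.
embedder = NvidiaTextEmbedder(
    model="NV-Embed-QA",
    api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia",
    truncate=EmbeddingTruncateMode.END,  # the plain string "END" is accepted too
)
embedder.warm_up()
result = embedder.run("I love pizza!")
print(len(result["embedding"]), result["meta"])
```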
json_response = self.client.query_function(self.nvcf_id, request) - response = GenerationResponse.from_dict(json_response) - - replies = [] - meta = [] - for choice in response.choices: - replies.append(choice.message.content) - meta.append( - { - "role": choice.message.role, - "finish_reason": choice.finish_reason, - "usage": { - "completion_tokens": response.usage.completion_tokens, - "prompt_tokens": response.usage.prompt_tokens, - "total_tokens": response.usage.total_tokens, - }, - } - ) - return replies, meta - - -@dataclass -class Message: - content: str - role: str - - -@dataclass -class GenerationRequest: - messages: List[Message] - temperature: float = 0.2 - top_p: float = 0.7 - max_tokens: int = 1024 - seed: Optional[int] = None - bad: Optional[List[str]] = None - stop: Optional[List[str]] = None - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -@dataclass -class Choice: - index: int - message: Message - finish_reason: str - - -@dataclass -class Usage: - completion_tokens: int - prompt_tokens: int - total_tokens: int - - -@dataclass -class GenerationResponse: - id: str - choices: List[Choice] - usage: Usage - - @classmethod - def from_dict(cls, data: dict) -> "GenerationResponse": - try: - return cls( - id=data["id"], - choices=[ - Choice( - index=choice["index"], - message=Message(content=choice["message"]["content"], role=choice["message"]["role"]), - finish_reason=choice["finish_reason"], - ) - for choice in data["choices"] - ], - usage=Usage( - completion_tokens=data["usage"]["completion_tokens"], - prompt_tokens=data["usage"]["prompt_tokens"], - total_tokens=data["usage"]["total_tokens"], - ), - ) - except (KeyError, TypeError) as e: - msg = f"Failed to parse {cls.__name__} from data: {data}" - raise ValueError(msg) from e diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py deleted file mode 100644 index 4e19d05ac..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: 2024-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional - - -@dataclass -class Message: - content: str - role: str - - -@dataclass -class GenerationRequest: - messages: List[Message] - temperature: float = 0.2 - top_p: float = 0.7 - max_tokens: int = 1024 - seed: Optional[int] = None - bad: Optional[List[str]] = None - stop: Optional[List[str]] = None - - def to_dict(self) -> Dict[str, Any]: - return asdict(self) - - -@dataclass -class Choice: - index: int - message: Message - finish_reason: str - - -@dataclass -class Usage: - completion_tokens: int - prompt_tokens: int - total_tokens: int - - -@dataclass -class GenerationResponse: - id: str - choices: List[Choice] - usage: Usage - - @classmethod - def from_dict(cls, data: dict) -> "GenerationResponse": - try: - return cls( - id=data["id"], - choices=[ - Choice( - index=choice["index"], - message=Message(content=choice["message"]["content"], role=choice["message"]["role"]), - finish_reason=choice["finish_reason"], - ) - for choice in data["choices"] - ], - usage=Usage( - completion_tokens=data["usage"]["completion_tokens"], - prompt_tokens=data["usage"]["prompt_tokens"], - total_tokens=data["usage"]["total_tokens"], - ), - ) - except (KeyError, TypeError) as e: - msg = f"Failed to parse {cls.__name__} from 
data: {data}" - raise ValueError(msg) from e diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/backend.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/backend.py deleted file mode 100644 index d14199daf..000000000 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/backend.py +++ /dev/null @@ -1,29 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple - - -class GeneratorBackend(ABC): - def __init__(self, model: str, model_kwargs: Optional[Dict[str, Any]] = None): - """ - Initialize the backend. - - :param model: - The name of the model to use. - :param model_kwargs: - Additional keyword arguments to pass to the model. - """ - self.model_name = model - self.model_kwargs = model_kwargs or {} - - @abstractmethod - def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: - """ - Invoke the backend and prompt the model. - - :param prompt: - Prompt text. - :return: - Vector representation of the generated texts related - metadata returned by the service. - """ - pass diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index b6db399e6..a286400ab 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -1,29 +1,29 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import warnings from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from ._nim_backend import NimBackend -from ._nvcf_backend import NvcfBackend -from .backend import GeneratorBackend +_DEFAULT_API_URL = "https://integrate.api.nvidia.com/v1" @component class NvidiaGenerator: """ - A component for generating text using generative models provided by - [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) - and NVIDIA Inference Microservices. + Generates text using generative models hosted with + [NVIDIA NIM](https://ai.nvidia.com) on on the [NVIDIA API Catalog](https://build.nvidia.com/explore/discover). + + ### Usage example - Usage example: ```python from haystack_integrations.components.generators.nvidia import NvidiaGenerator generator = NvidiaGenerator( - model="nv_llama2_rlhf_70b", + model="meta/llama3-70b-instruct", model_arguments={ "temperature": 0.2, "top_p": 0.7, @@ -37,12 +37,14 @@ class NvidiaGenerator: print(result["meta"]) print(result["usage"]) ``` + + You need an NVIDIA API key for this component to work. """ def __init__( self, - model: str, - api_url: Optional[str] = None, + model: Optional[str] = None, + api_url: str = _DEFAULT_API_URL, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_arguments: Optional[Dict[str, Any]] = None, ): @@ -51,23 +53,49 @@ def __init__( :param model: Name of the model to use for text generation. - See the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) + See the [NVIDIA NIMs](https://ai.nvidia.com) for more information on the supported models. 
+ `Note`: If no specific model along with locally hosted API URL is provided, + the system defaults to the available model found using /models API. + Check supported models at [NVIDIA NIM](https://ai.nvidia.com). :param api_key: - API key for the NVIDIA AI Foundation Endpoints. + API key for the NVIDIA NIM. Set it as the `NVIDIA_API_KEY` environment + variable or pass it here. :param api_url: - Custom API URL for the NVIDIA Inference Microservices. + Custom API URL for the NVIDIA NIM. :param model_arguments: - Additional arguments to pass to the model provider. Different models accept different arguments. - Search your model in the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) - to know the supported arguments. + Additional arguments to pass to the model provider. These arguments are + specific to a model. + Search your model in the [NVIDIA NIM](https://ai.nvidia.com) + to find the arguments it accepts. """ self._model = model - self._api_url = api_url + self._api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/chat/completions"]) self._api_key = api_key self._model_arguments = model_arguments or {} - self._backend: Optional[GeneratorBackend] = None + self._backend: Optional[Any] = None + + self.is_hosted = is_hosted(api_url) + + def default_model(self): + """Set default model in local NIM mode.""" + valid_models = [ + model.id for model in self._backend.models() if not model.base_model or model.base_model == model.id + ] + name = next(iter(valid_models), None) + if name: + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + stacklevel=2, + ) + self._model = self._backend.model_name = name + else: + error_message = "No locally hosted model was found." + raise ValueError(error_message) def warm_up(self): """ @@ -76,17 +104,18 @@ def warm_up(self): if self._backend is not None: return - if self._api_url is None: - if self._api_key is None: - msg = "API key is required for NVIDIA AI Foundation Endpoints." - raise ValueError(msg) - self._backend = NvcfBackend(self._model, api_key=self._api_key, model_kwargs=self._model_arguments) - else: - self._backend = NimBackend( - self._model, - api_url=self._api_url, - model_kwargs=self._model_arguments, - ) + if self._api_url == _DEFAULT_API_URL and self._api_key is None: + msg = "API key is required for hosted NVIDIA NIMs." 
+ raise ValueError(msg) + self._backend = NimBackend( + self._model, + api_url=self._api_url, + api_key=self._api_key, + model_kwargs=self._model_arguments, + ) + + if not self.is_hosted and not self._model: + self.default_model() def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index b8015cfda..da301d29d 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,3 +1,4 @@ -from .client import NvidiaCloudFunctionsClient +from .nim_backend import Model, NimBackend +from .utils import is_hosted, url_validation -__all__ = ["NvidiaCloudFunctionsClient"] +__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"] diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py deleted file mode 100644 index b486f05b3..000000000 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py +++ /dev/null @@ -1,82 +0,0 @@ -import copy -from dataclasses import dataclass -from typing import Dict, Optional - -import requests -from haystack.utils import Secret - -FUNCTIONS_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/functions" -INVOKE_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions" -STATUS_ENDPOINT = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status" - -ACCEPTED_STATUS_CODE = 202 - - -@dataclass -class AvailableNvidiaCloudFunctions: - name: str - id: str - status: Optional[str] = None - - -class NvidiaCloudFunctionsClient: - def __init__(self, *, api_key: Secret, headers: Dict[str, str], timeout: int = 60): - self.api_key = api_key.resolve_value() - if self.api_key is None: - msg = "Nvidia Cloud Functions API key is not set." - raise ValueError(msg) - - self.fetch_url_format = STATUS_ENDPOINT - self.headers = copy.deepcopy(headers) - self.headers.update( - { - "Authorization": f"Bearer {self.api_key}", - } - ) - self.timeout = timeout - self.session = requests.Session() - - def query_function(self, func_id: str, payload: Dict[str, str]) -> Dict[str, str]: - invoke_url = f"{INVOKE_ENDPOINT}/{func_id}" - - response = self.session.post(invoke_url, headers=self.headers, json=payload, timeout=self.timeout) - request_id = response.headers.get("NVCF-REQID") - if request_id is None: - msg = "NVCF-REQID header not found in response" - raise ValueError(msg) - - while response.status_code == ACCEPTED_STATUS_CODE: - fetch_url = f"{self.fetch_url_format}/{request_id}" - response = self.session.get(fetch_url, headers=self.headers, timeout=self.timeout) - - response.raise_for_status() - return response.json() - - def available_functions(self) -> Dict[str, AvailableNvidiaCloudFunctions]: - response = self.session.get(FUNCTIONS_ENDPOINT, headers=self.headers, timeout=self.timeout) - response.raise_for_status() - - return { - f["name"]: AvailableNvidiaCloudFunctions( - name=f["name"], - id=f["id"], - status=f.get("status"), - ) - for f in response.json()["functions"] - } - - def get_model_nvcf_id(self, model: str) -> str: - """ - Returns the Nvidia Cloud Functions UUID for the given model. 
- """ - - available_functions = self.available_functions() - func = available_functions.get(model) - if func is None: - msg = f"Model '{model}' was not found on the Nvidia Cloud Functions backend" - raise ValueError(msg) - elif func.status != "ACTIVE": - msg = f"Model '{model}' is not currently active/usable on the Nvidia Cloud Functions backend" - raise ValueError(msg) - - return func.id diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py new file mode 100644 index 000000000..f69862f0e --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -0,0 +1,131 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import requests +from haystack.utils import Secret + +REQUEST_TIMEOUT = 60 + + +@dataclass +class Model: + """ + Model information. + + id: unique identifier for the model, passed as model parameter for requests + aliases: list of aliases for the model + base_model: root model for the model + All aliases are deprecated and will trigger a warning when used. + """ + + id: str + aliases: Optional[List[str]] = field(default_factory=list) + base_model: Optional[str] = None + + +class NimBackend: + def __init__( + self, + model: str, + api_url: str, + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), + model_kwargs: Optional[Dict[str, Any]] = None, + ): + headers = { + "Content-Type": "application/json", + "accept": "application/json", + } + + if api_key: + headers["authorization"] = f"Bearer {api_key.resolve_value()}" + + self.session = requests.Session() + self.session.headers.update(headers) + + self.model = model + self.api_url = api_url + self.model_kwargs = model_kwargs or {} + + def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: + url = f"{self.api_url}/embeddings" + + res = self.session.post( + url, + json={ + "model": self.model, + "input": texts, + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + + data = res.json() + # Sort the embeddings by index, we don't know whether they're out of order or not + embeddings = [e["embedding"] for e in sorted(data["data"], key=lambda e: e["index"])] + + return embeddings, {"usage": data["usage"]} + + def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: + # We're using the chat completion endpoint as the NIM API doesn't support + # the /completions endpoint. So both the non-chat and chat generator will use this. + # This is the same for local containers and the cloud API. 
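+        # For orientation, the payload assembled below follows the OpenAI-compatible
+        # chat schema, and a typical response parsed further down looks roughly like
+        # (illustrative values):
+        #   {"choices": [{"index": 0,
+        #                 "message": {"role": "assistant", "content": "..."},
+        #                 "finish_reason": "stop"}],
+        #    "usage": {"prompt_tokens": 19, "completion_tokens": 2, "total_tokens": 21}}
+        # "finish_reason" and "usage.completion_tokens" may be missing, which is why
+        # they are copied into the returned metadata only when present.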
+ url = f"{self.api_url}/chat/completions" + + res = self.session.post( + url, + json={ + "model": self.model, + "messages": [ + { + "role": "user", + "content": prompt, + }, + ], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + + completions = res.json() + choices = completions["choices"] + # Sort the choices by index, we don't know whether they're out of order or not + choices.sort(key=lambda c: c["index"]) + replies = [] + meta = [] + for choice in choices: + message = choice["message"] + replies.append(message["content"]) + choice_meta = { + "role": message["role"], + "usage": { + "prompt_tokens": completions["usage"]["prompt_tokens"], + "total_tokens": completions["usage"]["total_tokens"], + }, + } + # These fields could be null, the others will always be present + if "finish_reason" in choice: + choice_meta["finish_reason"] = choice["finish_reason"] + if "completion_tokens" in completions["usage"]: + choice_meta["usage"]["completion_tokens"] = completions["usage"]["completion_tokens"] + + meta.append(choice_meta) + + return replies, meta + + def models(self) -> List[Model]: + url = f"{self.api_url}/models" + + res = self.session.get( + url, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + + data = res.json()["data"] + models = [Model(element["id"]) for element in data if "id" in element] + if not models: + msg = f"No hosted model were found at URL '{url}'." + raise ValueError(msg) + return models diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py new file mode 100644 index 000000000..7d4dfc3b4 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -0,0 +1,47 @@ +import warnings +from typing import List +from urllib.parse import urlparse, urlunparse + + +def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: + """ + Validate and normalize an API URL. + + :param api_url: + The API URL to validate and normalize. + :param default_api_url: + The default API URL for comparison. + :param allowed_paths: + A list of allowed base paths that are valid if present in the URL. + :returns: + A normalized version of the API URL with '/v1' path appended, if needed. + :raises ValueError: + If the base URL path is not recognized or does not match expected format. + """ + ## Making sure /v1 in added to the url, followed by infer_path + result = urlparse(api_url) + expected_format = "Expected format is 'http://host:port'." + + if api_url == default_api_url: + return api_url + if result.path: + normalized_path = result.path.strip("/") + if normalized_path == "v1": + pass + elif normalized_path in allowed_paths: + warn_msg = f"{expected_format} Rest is ignored." + warnings.warn(warn_msg, stacklevel=2) + else: + err_msg = f"Base URL path is not recognized. 
{expected_format}" + raise ValueError(err_msg) + + base_url = urlunparse((result.scheme, result.netloc, "v1", "", "", "")) + return base_url + + +def is_hosted(api_url: str): + """""" + return urlparse(api_url).netloc in [ + "integrate.api.nvidia.com", + "ai.api.nvidia.com", + ] diff --git a/integrations/nvidia/tests/__init__.py b/integrations/nvidia/tests/__init__.py index e873bc332..47611e0b9 100644 --- a/integrations/nvidia/tests/__init__.py +++ b/integrations/nvidia/tests/__init__.py @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .conftest import MockBackend + +__all__ = ["MockBackend"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py new file mode 100644 index 000000000..794c994ff --- /dev/null +++ b/integrations/nvidia/tests/conftest.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, List, Optional, Tuple + +import pytest +from haystack.utils import Secret +from haystack_integrations.utils.nvidia import Model, NimBackend +from requests_mock import Mocker + + +class MockBackend(NimBackend): + def __init__(self, model: str, api_key: Optional[Secret] = None, model_kwargs: Optional[Dict[str, Any]] = None): + api_key = api_key or Secret.from_env_var("NVIDIA_API_KEY") + super().__init__(model, api_url="", api_key=api_key, model_kwargs=model_kwargs or {}) + + def embed(self, texts): + inputs = texts + data = [[0.1, 0.2, 0.3] for i in range(len(inputs))] + return data, {"usage": {"total_tokens": 4, "prompt_tokens": 4}} + + def models(self): + return [Model(id="aa")] + + def generate(self) -> Tuple[List[str], List[Dict[str, Any]]]: + return ( + ["This is a mocked response."], + [{"role": "assistant", "usage": {"prompt_tokens": 5, "total_tokens": 10, "completion_tokens": 5}}], + ) + + +@pytest.fixture +def mock_local_models(requests_mock: Mocker) -> None: + requests_mock.get( + "http://localhost:8080/v1/models", + json={ + "data": [ + { + "id": "model1", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": "model1", + }, + ] + }, + ) diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py new file mode 100644 index 000000000..072807685 --- /dev/null +++ b/integrations/nvidia/tests/test_base_url.py @@ -0,0 +1,64 @@ +import pytest +from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder +from haystack_integrations.components.generators.nvidia import NvidiaGenerator + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8888/embeddings", + "http://0.0.0.0:8888/rankings", + "http://0.0.0.0:8888/v1/rankings", + "http://localhost:8888/chat/completions", + "http://localhost:8888/v1/chat/completions", + ], +) +@pytest.mark.parametrize( + "embedder", + [NvidiaDocumentEmbedder, NvidiaTextEmbedder], +) +def test_base_url_invalid_not_hosted(base_url: str, embedder) -> None: + with pytest.raises(ValueError): + embedder(api_url=base_url, model="x") + + +@pytest.mark.parametrize( + "base_url", + ["http://localhost:8080/v1/embeddings", "http://0.0.0.0:8888/v1/embeddings"], +) +@pytest.mark.parametrize( + "embedder", + [NvidiaDocumentEmbedder, NvidiaTextEmbedder], +) +def test_base_url_valid_embedder(base_url: str, embedder) -> None: + with pytest.warns(UserWarning): + embedder(api_url=base_url) + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8080/v1/chat/completions", + "http://0.0.0.0:8888/v1/chat/completions", + ], +) +def 
test_base_url_valid_generator(base_url: str) -> None: + with pytest.warns(UserWarning): + NvidiaGenerator( + api_url=base_url, + model="mistralai/mixtral-8x7b-instruct-v0.1", + ) + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8888/embeddings", + "http://0.0.0.0:8888/rankings", + "http://0.0.0.0:8888/v1/rankings", + "http://localhost:8888/chat/completions", + ], +) +def test_base_url_invalid_generator(base_url: str) -> None: + with pytest.raises(ValueError): + NvidiaGenerator(api_url=base_url, model="x") diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 7ac89d5e2..6562a0ea9 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -1,19 +1,21 @@ import os -from unittest.mock import Mock, patch import pytest from haystack import Document from haystack.utils import Secret -from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder +from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaDocumentEmbedder + +from . import MockBackend class TestNvidiaDocumentEmbedder: def test_init_default(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") - embedder = NvidiaDocumentEmbedder("nvolveqa_40k") + embedder = NvidiaDocumentEmbedder() assert embedder.api_key == Secret.from_env_var("NVIDIA_API_KEY") - assert embedder.model == "nvolveqa_40k" + assert embedder.model == "NV-Embed-QA" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" assert embedder.prefix == "" assert embedder.suffix == "" assert embedder.batch_size == 32 @@ -22,25 +24,28 @@ def test_init_default(self, monkeypatch): assert embedder.embedding_separator == "\n" def test_init_with_parameters(self): - embedder = NvidiaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), - model="nvolveqa_40k", - prefix="prefix", - suffix="suffix", - batch_size=30, - progress_bar=False, - meta_fields_to_embed=["test_field"], - embedding_separator=" | ", - ) - - assert embedder.api_key == Secret.from_token("fake-api-key") - assert embedder.model == "nvolveqa_40k" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 30 - assert embedder.progress_bar is False - assert embedder.meta_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " + with pytest.raises(ValueError): + embedder = NvidiaDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="nvolveqa_40k", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", + prefix="prefix", + suffix="suffix", + batch_size=30, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + + assert embedder.api_key == Secret.from_token("fake-api-key") + assert embedder.model == "nvolveqa_40k" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 30 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -56,7 +61,7 @@ def test_to_dict(self, monkeypatch): "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { "api_key": 
{"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": None, + "api_url": "https://ai.api.nvidia.com/v1/retrieval/nvidia", "model": "playground_nvolveqa_40k", "prefix": "", "suffix": "", @@ -64,6 +69,7 @@ def test_to_dict(self, monkeypatch): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", + "truncate": None, }, } @@ -78,9 +84,28 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + truncate=EmbeddingTruncateMode.END, ) data = component.to_dict() assert data == { + "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "api_url": "https://example.com/v1", + "model": "playground_nvolveqa_40k", + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 10, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + "truncate": "END", + }, + } + + def from_dict(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, @@ -92,8 +117,19 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "progress_bar": False, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "truncate": "START", }, } + component = NvidiaDocumentEmbedder.from_dict(data) + assert component.model == "nvolveqa_40k" + assert component.api_url == "https://example.com/v1" + assert component.prefix == "prefix" + assert component.suffix == "suffix" + assert component.batch_size == 32 + assert component.progress_bar + assert component.meta_fields_to_embed == [] + assert component.embedding_separator == "\n" + assert component.truncate == EmbeddingTruncateMode.START def test_prepare_texts_to_embed_w_metadata(self): documents = [ @@ -138,26 +174,20 @@ def test_prepare_texts_to_embed_w_suffix(self): "my_prefix document number 4 my_suffix", ] - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_embed_batch(self, mock_client_class): + def test_embed_batch(self): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") embedder = NvidiaDocumentEmbedder( - "playground_nvolveqa_40k", - api_key=Secret.from_token("fake-api-key"), + model, + api_key=api_key, ) - def mock_query_function(_, payload): - inputs = payload["input"] - data = [{"index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] - return {"data": data, "usage": {"total_tokens": 4, "prompt_tokens": 4}} - - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=mock_query_function, - ) - mock_client_class.return_value = mock_client embedder.warm_up() + embedder.backend = MockBackend( + model=model, + api_key=api_key, + ) embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2) @@ -170,34 +200,64 @@ def mock_query_function(_, payload): assert metadata == {"usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}} - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run(self, mock_client_class): + 
@pytest.mark.usefixtures("mock_local_models") + def test_run_default_model(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] + api_key = Secret.from_token("fake-api-key") - model = "playground_nvolveqa_40k" embedder = NvidiaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), - model=model, + api_key=api_key, + model=None, + api_url="http://localhost:8080/v1", prefix="prefix ", suffix=" suffix", meta_fields_to_embed=["topic"], embedding_separator=" | ", ) - def mock_query_function(_, payload): - inputs = payload["input"] - data = [{"index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] - return {"data": data, "usage": {"total_tokens": 4, "prompt_tokens": 4}} + with pytest.warns(UserWarning) as record: + embedder.warm_up() + assert len(record) == 1 + assert "Default model is set as:" in str(record[0].message) + assert embedder.model == "model1" - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=mock_query_function, + embedder.backend = MockBackend(model=embedder.model, api_key=api_key) + + result = embedder.run(documents=docs) + + documents_with_embeddings = result["documents"] + metadata = result["meta"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 3 + assert all(isinstance(x, float) for x in doc.embedding) + assert metadata == {"usage": {"prompt_tokens": 4, "total_tokens": 4}} + + def test_run(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + api_key = Secret.from_token("fake-api-key") + model = "playground_nvolveqa_40k" + embedder = NvidiaDocumentEmbedder( + api_key=api_key, + model=model, + prefix="prefix ", + suffix=" suffix", + meta_fields_to_embed=["topic"], + embedding_separator=" | ", ) - mock_client_class.return_value = mock_client + embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) result = embedder.run(documents=docs) @@ -213,15 +273,15 @@ def mock_query_function(_, payload): assert all(isinstance(x, float) for x in doc.embedding) assert metadata == {"usage": {"prompt_tokens": 4, "total_tokens": 4}} - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run_custom_batch_size(self, mock_client_class): + def test_run_custom_batch_size(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] + api_key = Secret.from_token("fake-api-key") model = "playground_nvolveqa_40k" embedder = NvidiaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), + api_key=api_key, model=model, prefix="prefix ", suffix=" suffix", @@ -230,17 +290,8 @@ def test_run_custom_batch_size(self, mock_client_class): batch_size=1, ) - def mock_query_function(_, payload): - inputs = payload["input"] - data = [{"index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] - return {"data": data, "usage": {"total_tokens": 4, "prompt_tokens": 4}} - - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=mock_query_function, - ) - 
mock_client_class.return_value = mock_client embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) result = embedder.run(documents=docs) @@ -257,21 +308,13 @@ def mock_query_function(_, payload): assert metadata == {"usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run_wrong_input_format(self, mock_client_class): - embedder = NvidiaDocumentEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) - - def mock_query_function(_, payload): - inputs = payload["input"] - data = [{"index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] - return {"data": data, "usage": {"total_tokens": 4, "prompt_tokens": 4}} + def test_run_wrong_input_format(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaDocumentEmbedder(model, api_key=api_key) - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=mock_query_function, - ) - mock_client_class.return_value = mock_client embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) string_input = "text" list_integers_input = [1, 2, 3] @@ -282,21 +325,13 @@ def mock_query_function(_, payload): with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run_on_empty_list(self, mock_client_class): - embedder = NvidiaDocumentEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) - - def mock_query_function(_, payload): - inputs = payload["input"] - data = [{"index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] - return {"data": data, "usage": {"total_tokens": 4, "prompt_tokens": 4}} + def test_run_on_empty_list(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaDocumentEmbedder(model, api_key=api_key) - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=mock_query_function, - ) - mock_client_class.return_value = mock_client embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) empty_list_input = [] result = embedder.run(documents=empty_list_input) @@ -308,11 +343,21 @@ def mock_query_function(_, payload): not os.environ.get("NVIDIA_API_KEY", None), reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", ) + @pytest.mark.skipif( + not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " + "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", + ) @pytest.mark.integration - def test_run_integration(self): - embedder = NvidiaDocumentEmbedder("playground_nvolveqa_40k") + def test_run_integration_with_nim_backend(self): + model = os.environ["NVIDIA_NIM_EMBEDDER_MODEL"] + url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] + embedder = NvidiaDocumentEmbedder( + model=model, + api_url=url, + api_key=None, + ) embedder.warm_up() - docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), @@ -328,20 +373,18 @@ def 
test_run_integration(self): assert isinstance(doc.embedding[0], float) @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), - reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " - "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", ) @pytest.mark.integration - def test_run_integration_with_nim_backend(self): - model = os.environ["NVIDIA_NIM_EMBEDDER_MODEL"] - url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] + def test_run_integration_with_api_catalog(self): embedder = NvidiaDocumentEmbedder( - model=model, - api_url=url, - api_key=None, + model="NV-Embed-QA", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), ) embedder.warm_up() + docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 9a157a9d1..9fff9c2e8 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -2,11 +2,39 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import Mock, patch import pytest from haystack.utils import Secret from haystack_integrations.components.generators.nvidia import NvidiaGenerator +from requests_mock import Mocker + + +@pytest.fixture +def mock_local_chat_completion(requests_mock: Mocker) -> None: + requests_mock.post( + "http://localhost:8080/v1/chat/completions", + json={ + "choices": [ + { + "message": {"content": "Hello!", "role": "system"}, + "usage": {"prompt_tokens": 3, "total_tokens": 5, "completion_tokens": 9}, + "finish_reason": "stop", + "index": 0, + }, + { + "message": {"content": "How are you?", "role": "system"}, + "usage": {"prompt_tokens": 3, "total_tokens": 5, "completion_tokens": 9}, + "finish_reason": "stop", + "index": 1, + }, + ], + "usage": { + "prompt_tokens": 3, + "total_tokens": 5, + "completion_tokens": 9, + }, + }, + ) class TestNvidiaGenerator: @@ -55,7 +83,7 @@ def test_to_dict(self, monkeypatch): assert data == { "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", "init_parameters": { - "api_url": None, + "api_url": "https://integrate.api.nvidia.com/v1", "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, "model": "playground_nemotron_steerlm_8b", "model_arguments": {}, @@ -81,7 +109,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": "https://my.url.com", + "api_url": "https://my.url.com/v1", "model": "playground_nemotron_steerlm_8b", "model_arguments": { "temperature": 0.2, @@ -94,105 +122,65 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - @patch("haystack_integrations.components.generators.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run(self, mock_client_class): + @pytest.mark.skipif( + not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + reason="Export an env var called 
NVIDIA_NIM_GENERATOR_MODEL containing the hosted model name and " + "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", + ) + @pytest.mark.integration + def test_run_integration_with_nim_backend(self): + model = os.environ["NVIDIA_NIM_GENERATOR_MODEL"] + url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] generator = NvidiaGenerator( - model="playground_nemotron_steerlm_8b", - api_key=Secret.from_token("fake-api-key"), + model=model, + api_url=url, + api_key=None, model_arguments={ "temperature": 0.2, - "top_p": 0.7, - "max_tokens": 1024, - "seed": None, - "bad": None, - "stop": None, }, ) - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=Mock( - return_value={ - "id": "some_id", - "choices": [ - { - "index": 0, - "message": {"content": "42", "role": "assistant"}, - "finish_reason": "stop", - } - ], - "usage": {"total_tokens": 21, "prompt_tokens": 19, "completion_tokens": 2}, - } - ), - ) - mock_client_class.return_value = mock_client generator.warm_up() - result = generator.run(prompt="What is the answer?") - mock_client.query_function.assert_called_once_with( - "some_id", - { - "messages": [ - {"content": "What is the answer?", "role": "user"}, - ], - "temperature": 0.2, - "top_p": 0.7, - "max_tokens": 1024, - "seed": None, - "bad": None, - "stop": None, - }, - ) - assert result == { - "replies": ["42"], - "meta": [ - { - "finish_reason": "stop", - "role": "assistant", - "usage": { - "total_tokens": 21, - "prompt_tokens": 19, - "completion_tokens": 2, - }, - }, - ], - } - @pytest.mark.skipif( - not os.environ.get("NVIDIA_API_KEY", None), - reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", - ) + assert result["replies"] + assert result["meta"] + @pytest.mark.integration - def test_run_integration_with_nvcf_backend(self): + @pytest.mark.usefixtures("mock_local_models") + @pytest.mark.usefixtures("mock_local_chat_completion") + def test_run_integration_with_default_model_nim_backend(self): + model = None + url = "http://localhost:8080/v1" generator = NvidiaGenerator( - model="playground_nv_llama2_rlhf_70b", + model=model, + api_url=url, + api_key=None, model_arguments={ "temperature": 0.2, - "top_p": 0.7, - "max_tokens": 1024, - "seed": None, - "bad": None, - "stop": None, }, ) - generator.warm_up() + with pytest.warns(UserWarning) as record: + generator.warm_up() + assert len(record) == 1 + assert "Default model is set as:" in str(record[0].message) + assert generator._model == "model1" + assert not generator.is_hosted + result = generator.run(prompt="What is the answer?") assert result["replies"] assert result["meta"] @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), - reason="Export an env var called NVIDIA_NIM_GENERATOR_MODEL containing the hosted model name and " - "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", ) @pytest.mark.integration - def test_run_integration_with_nim_backend(self): - model = os.environ["NVIDIA_NIM_GENERATOR_MODEL"] - url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] + def test_run_integration_with_api_catalog(self): generator = NvidiaGenerator( - model=model, - api_url=url, - api_key=None, + model="meta/llama3-70b-instruct", + api_url="https://integrate.api.nvidia.com/v1", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), model_arguments={ 
"temperature": 0.2, }, @@ -202,3 +190,27 @@ def test_run_integration_with_nim_backend(self): assert result["replies"] assert result["meta"] + + def test_local_nim_without_key(self) -> None: + generator = NvidiaGenerator( + model="BOGUS", + api_url="http://localhost:8000", + api_key=None, + ) + generator.warm_up() + + def test_hosted_nim_without_key(self): + generator0 = NvidiaGenerator( + model="BOGUS", + api_url="https://integrate.api.nvidia.com/v1", + api_key=None, + ) + with pytest.raises(ValueError): + generator0.warm_up() + + generator1 = NvidiaGenerator( + model="BOGUS", + api_key=None, + ) + with pytest.raises(ValueError): + generator1.warm_up() diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 39ee02206..7c0a7000d 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -1,32 +1,36 @@ import os -from unittest.mock import Mock, patch import pytest from haystack.utils import Secret -from haystack_integrations.components.embedders.nvidia import NvidiaTextEmbedder +from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaTextEmbedder + +from . import MockBackend class TestNvidiaTextEmbedder: def test_init_default(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") - embedder = NvidiaTextEmbedder("nvolveqa_40k") + embedder = NvidiaTextEmbedder() assert embedder.api_key == Secret.from_env_var("NVIDIA_API_KEY") - assert embedder.model == "nvolveqa_40k" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" assert embedder.prefix == "" assert embedder.suffix == "" def test_init_with_parameters(self): - embedder = NvidiaTextEmbedder( - api_key=Secret.from_token("fake-api-key"), - model="nvolveqa_40k", - prefix="prefix", - suffix="suffix", - ) - assert embedder.api_key == Secret.from_token("fake-api-key") - assert embedder.model == "nvolveqa_40k" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" + with pytest.raises(ValueError): + embedder = NvidiaTextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="nvolveqa_40k", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", + prefix="prefix", + suffix="suffix", + ) + assert embedder.api_key == Secret.from_token("fake-api-key") + assert embedder.model == "nvolveqa_40k" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -42,10 +46,11 @@ def test_to_dict(self, monkeypatch): "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": None, + "api_url": "https://ai.api.nvidia.com/v1/retrieval/nvidia", "model": "nvolveqa_40k", "prefix": "", "suffix": "", + "truncate": None, }, } @@ -53,37 +58,75 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") component = NvidiaTextEmbedder( model="nvolveqa_40k", + api_url="https://example.com", prefix="prefix", suffix="suffix", + truncate=EmbeddingTruncateMode.START, ) data = component.to_dict() assert data == { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { "api_key": {"env_vars": 
["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": None, + "api_url": "https://example.com/v1", "model": "nvolveqa_40k", "prefix": "prefix", "suffix": "suffix", + "truncate": "START", }, } - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run(self, mock_client_class): - embedder = NvidiaTextEmbedder( - "playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key"), prefix="prefix ", suffix=" suffix" - ) - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=Mock( - return_value={ - "data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}], - "usage": {"total_tokens": 4, "prompt_tokens": 4}, - } - ), - ) - mock_client_class.return_value = mock_client + def from_dict(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "api_url": "https://example.com", + "model": "nvolveqa_40k", + "prefix": "prefix", + "suffix": "suffix", + "truncate": "START", + }, + } + component = NvidiaTextEmbedder.from_dict(data) + assert component.model == "nvolveqa_40k" + assert component.api_url == "https://example.com/v1" + assert component.prefix == "prefix" + assert component.suffix == "suffix" + assert component.truncate == "START" + + @pytest.mark.usefixtures("mock_local_models") + def test_run_default_model(self): + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaTextEmbedder(api_url="http://localhost:8080/v1", api_key=api_key) + + assert embedder.model is None + + with pytest.warns(UserWarning) as record: + embedder.warm_up() + + assert len(record) == 1 + assert "Default model is set as:" in str(record[0].message) + assert embedder.model == "model1" + + embedder.backend = MockBackend(model=embedder.model, api_key=api_key) + + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 3 + assert all(isinstance(x, float) for x in result["embedding"]) + assert result["meta"] == { + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + + def test_run(self): + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaTextEmbedder("playground_nvolveqa_40k", api_key=api_key, prefix="prefix ", suffix=" suffix") + embedder.warm_up() + embedder.backend = MockBackend(model="playground_nvolveqa_40k", api_key=api_key) + result = embedder.run(text="The food was delicious") assert len(result["embedding"]) == 3 @@ -92,20 +135,11 @@ def test_run(self, mock_client_class): "usage": {"prompt_tokens": 4, "total_tokens": 4}, } - @patch("haystack_integrations.components.embedders.nvidia._nvcf_backend.NvidiaCloudFunctionsClient") - def test_run_wrong_input_format(self, mock_client_class): - embedder = NvidiaTextEmbedder("playground_nvolveqa_40k", api_key=Secret.from_token("fake-api-key")) - mock_client = Mock( - get_model_nvcf_id=Mock(return_value="some_id"), - query_function=Mock( - return_value={ - "data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}], - "usage": {"total_tokens": 4, "prompt_tokens": 4}, - } - ), - ) - mock_client_class.return_value = mock_client + def test_run_wrong_input_format(self): + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaTextEmbedder("playground_nvolveqa_40k", api_key=api_key) embedder.warm_up() + embedder.backend = MockBackend(model="playground_nvolveqa_40k", api_key=api_key) 
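        # Pattern used throughout these tests: warm_up() builds a real NimBackend
        # (no HTTP request is made at that point because the model is already set),
        # and the backend is then replaced with MockBackend from conftest.py so the
        # test stays offline.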
list_integers_input = [1, 2, 3] @@ -113,12 +147,19 @@ def test_run_wrong_input_format(self, mock_client_class): embedder.run(text=list_integers_input) @pytest.mark.skipif( - not os.environ.get("NVIDIA_API_KEY", None), - reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", + not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), + reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " + "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", ) @pytest.mark.integration - def test_run_integration_with_nvcf_backend(self): - embedder = NvidiaTextEmbedder("playground_nvolveqa_40k") + def test_run_integration_with_nim_backend(self): + model = os.environ["NVIDIA_NIM_EMBEDDER_MODEL"] + url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] + embedder = NvidiaTextEmbedder( + model=model, + api_url=url, + api_key=None, + ) embedder.warm_up() result = embedder.run("A transformer is a deep learning architecture") @@ -129,18 +170,15 @@ def test_run_integration_with_nvcf_backend(self): assert "usage" in meta @pytest.mark.skipif( - not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), - reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " - "NVIDIA_NIM_ENDPOINT_URL containing the local URL to call.", + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", ) @pytest.mark.integration - def test_run_integration_with_nim_backend(self): - model = os.environ["NVIDIA_NIM_EMBEDDER_MODEL"] - url = os.environ["NVIDIA_NIM_ENDPOINT_URL"] + def test_run_integration_with_api_catalog(self): embedder = NvidiaTextEmbedder( - model=model, - api_url=url, - api_key=None, + model="NV-Embed-QA", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), ) embedder.warm_up() diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md new file mode 100644 index 000000000..6467aa868 --- /dev/null +++ b/integrations/ollama/CHANGELOG.md @@ -0,0 +1,50 @@ +# Changelog + +## [integrations/ollama-v0.0.7] - 2024-05-31 + +### 🚀 Features + +- Add streaming support to OllamaChatGenerator (#757) + +## [integrations/ollama-v0.0.6] - 2024-04-18 + +### 📚 Documentation + +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Update docstrings (#499) + +### Ollama + +- Change testing workflow (#551) +- Add ollama embedder example (#669) + +## [integrations/ollama-v0.0.5] - 2024-02-28 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +### ⚙️ Miscellaneous Tasks + +- Use `serialize_callable` instead of `serialize_callback_handler` in Ollama (#461) + +## [integrations/ollama-v0.0.4] - 2024-02-12 + +### Ollama + +- Generate api docs (#332) + +## [integrations/ollama-v0.0.3] - 2024-01-16 + +## [integrations/ollama-v0.0.1] - 2024-01-03 + + diff --git a/integrations/ollama/pydoc/config.yml b/integrations/ollama/pydoc/config.yml index 4207ea997..e8f2ca6e5 100644 --- a/integrations/ollama/pydoc/config.yml +++ b/integrations/ollama/pydoc/config.yml @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: 
Ollama integration for Haystack category_slug: integrations-api title: Ollama diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index d98e2a6ca..5187de31f 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -16,7 +16,7 @@ authors = [ { name = "deepset GmbH", email = "info@deepset.ai" }, ] classifiers = [ - "License :: OSI Approved :: Apache Software License", + "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3.8", @@ -49,22 +49,17 @@ git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] @@ -72,27 +67,13 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -138,9 +119,15 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -164,25 +151,15 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] markers = [ "integration: marks tests as slow (deselect with '-m \"not integration\"')", ] -addopts = [ - "--import-mode=importlib", -] +addopts = ["--import-mode=importlib"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*" -] +module = ["haystack.*", "haystack_integrations.*", "pytest.*"] ignore_missing_imports = true diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 2abf3066b..a95d8c4fb 100644 --- 
a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List, Optional +import json +from typing import Any, Callable, Dict, List, Optional import requests from haystack import component -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, StreamingChunk from requests import Response @@ -38,6 +39,7 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, template: Optional[str] = None, timeout: int = 120, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ :param model: @@ -52,6 +54,9 @@ def __init__( The full prompt template (overrides what is defined in the Ollama Modelfile). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ self.timeout = timeout @@ -59,11 +64,12 @@ def __init__( self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model + self.streaming_callback = streaming_callback def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} - def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]: + def _create_json_payload(self, messages: List[ChatMessage], stream=False, generation_kwargs=None) -> Dict[str, Any]: """ Returns A dictionary of JSON arguments for a POST request to an Ollama service """ @@ -71,7 +77,7 @@ def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=No return { "messages": [self._message_to_dict(message) for message in messages], "model": self.model, - "stream": False, + "stream": stream, "template": self.template, "options": generation_kwargs, } @@ -85,6 +91,41 @@ def _build_message_from_ollama_response(self, ollama_response: Response) -> Chat message.meta.update({key: value for key, value in json_content.items() if key != "message"}) return message + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: + """ + Converts a list of chunks response required Haystack format. + """ + + replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))] + meta = {key: value for key, value in chunks[0].meta.items() if key != "message"} + + return {"replies": replies, "meta": [meta]} + + def _build_chunk(self, chunk_response: Any) -> StreamingChunk: + """ + Converts the response from the Ollama API to a StreamingChunk. 
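
With the new `streaming_callback` parameter, callers can consume the reply incrementally while still getting the assembled `ChatMessage` back from `run()`. A minimal usage sketch, assuming a locally running Ollama server with the chosen model already pulled (the model name here is only an example):

```python
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaChatGenerator


def print_chunk(chunk: StreamingChunk) -> None:
    # Invoked once per streamed fragment as it arrives from Ollama.
    print(chunk.content, end="", flush=True)


generator = OllamaChatGenerator(model="orca-mini", streaming_callback=print_chunk)
result = generator.run([ChatMessage.from_user("What is the capital of France?")])

# The full reply is still returned once streaming has finished.
print(result["replies"][0].content)
```
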
+ """ + decoded_chunk = json.loads(chunk_response.decode("utf-8")) + + content = decoded_chunk["message"]["content"] + meta = {key: value for key, value in decoded_chunk.items() if key != "message"} + meta["role"] = decoded_chunk["message"]["role"] + + chunk_message = StreamingChunk(content, meta) + return chunk_message + + def _handle_streaming_response(self, response) -> List[StreamingChunk]: + """ + Handles Streaming response cases + """ + chunks: List[StreamingChunk] = [] + for chunk in response.iter_lines(): + chunk_delta: StreamingChunk = self._build_chunk(chunk) + chunks.append(chunk_delta) + if self.streaming_callback is not None: + self.streaming_callback(chunk_delta) + return chunks + @component.output_types(replies=List[ChatMessage]) def run( self, @@ -100,16 +141,24 @@ def run( Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, etc. See the [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :param streaming_callback: + A callback function that will be called with each response chunk in streaming mode. :returns: A dictionary with the following keys: - `replies`: The responses from the model """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - json_payload = self._create_json_payload(messages, generation_kwargs) + stream = self.streaming_callback is not None + + json_payload = self._create_json_payload(messages, stream, generation_kwargs) - response = requests.post(url=self.url, json=json_payload, timeout=self.timeout) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) # throw error on unsuccessful response response.raise_for_status() + if stream: + chunks: List[StreamingChunk] = self._handle_streaming_response(response) + return self._convert_to_streaming_response(chunks) + return {"replies": [self._build_message_from_ollama_response(response)]} diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index ff6a840b8..e09208bb8 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -41,7 +41,9 @@ def test_init(self): assert component.timeout == 5 def test_create_json_payload(self, chat_messages): - observed = OllamaChatGenerator(model="some_model")._create_json_payload(chat_messages, {"temperature": 0.1}) + observed = OllamaChatGenerator(model="some_model")._create_json_payload( + chat_messages, False, {"temperature": 0.1} + ) expected = { "messages": [ {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}, @@ -125,3 +127,27 @@ def test_run_model_unavailable(self): "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?" 
) component.run([message]) + + @pytest.mark.integration + def test_run_with_streaming(self): + streaming_callback = Mock() + chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) + + chat_history = [ + {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, + {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, + {"role": "user", "content": "And what is the second largest?"}, + ] + + chat_messages = [ + ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) + for message in chat_history + ] + + response = chat_generator.run(chat_messages) + + streaming_callback.assert_called() + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md new file mode 100644 index 000000000..6509d1e0f --- /dev/null +++ b/integrations/opensearch/CHANGELOG.md @@ -0,0 +1,105 @@ +# Changelog + +## [integrations/opensearch-v0.9.0] - 2024-08-01 + +### 🚀 Features + +- Support aws authentication with OpenSearchDocumentStore (#920) + +## [integrations/opensearch-v0.8.1] - 2024-07-15 + +### 🚀 Features + +- Add raise_on_failure param to OpenSearch retrievers (#852) +- Add filter_policy to opensearch integration (#822) + +### 🐛 Bug Fixes + +- `OpenSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#895) + +### ⚙️ Miscellaneous Tasks + +- Update ruff invocation to include check parameter (#853) + +## [integrations/opensearch-v0.7.1] - 2024-06-27 + +### 🐛 Bug Fixes + +- Serialization for custom_query in OpenSearch retrievers (#851) +- Support legacy filters with OpenSearchDocumentStore (#850) + +## [integrations/opensearch-v0.7.0] - 2024-06-25 + +### 🚀 Features + +- Defer the database connection to when it's needed (#753) +- Improve `OpenSearchDocumentStore.__init__` arguments (#739) +- Return_embeddings flag for opensearch (#784) +- Add create_index option to OpenSearchDocumentStore (#840) +- Add custom_query param to OpenSearch retrievers (#841) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Fixing opensearch docstrings (#521) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) + +### Opensearch + +- Generate API docs (#324) + +## [integrations/opensearch-v0.2.0] - 2024-01-17 + +### 🐛 Bug Fixes + +- Fix links in docstrings (#188) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +## [integrations/opensearch-v0.1.1] - 2023-12-05 + +### 🐛 Bug Fixes + +- Fix import and increase version (#77) + + + +## [integrations/opensearch-v0.1.0] - 2023-12-04 + +### 🐛 Bug Fixes + +- Fix license headers + + +## [integrations/opensearch-v0.0.2] - 2023-11-30 + +### 🚀 Features + +- Extend OpenSearch params support (#70) + +### Build + +- Bump OpenSearch integration version to 0.0.2 (#71) + +## [integrations/opensearch-v0.0.1] - 2023-11-30 + +### 🚀 Features + +- [OpenSearch] add document store, BM25Retriever and EmbeddingRetriever (#68) + + diff --git a/integrations/opensearch/pydoc/config.yml b/integrations/opensearch/pydoc/config.yml index dfcb23b5f..7b2e20d83 100644 --- 
a/integrations/opensearch/pydoc/config.yml +++ b/integrations/opensearch/pydoc/config.yml @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: OpenSearch integration for Haystack category_slug: integrations-api title: OpenSearch diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 38f93f3a2..b7e5e3da6 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -24,10 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "opensearch-py>=2,<3", -] +dependencies = ["haystack-ai", "opensearch-py>=2,<3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch#readme" @@ -49,50 +44,32 @@ git_describe_command = 'git describe --tags --match="integrations/opensearch-v[0 dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "pytest-xdist", "haystack-pydoc-tools", + "boto3", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -138,9 +115,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -165,26 +148,14 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" -markers = [ - "unit: unit tests", - "integration: integration tests" -] +markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*", - "opensearchpy.*", -] +module = ["botocore.*", "boto3.*", "haystack.*", "haystack_integrations.*", "pytest.*", "opensearchpy.*"] ignore_missing_imports = true diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index 0ad257b42..640f349b2 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -1,15 +1,26 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +import logging +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +logger = logging.getLogger(__name__) + @component class OpenSearchBM25Retriever: + """ + Fetches documents from OpenSearchDocumentStore using the keyword-based BM25 algorithm. + + BM25 computes a weighted word overlap between the query string and a document to determine its similarity. + """ + def __init__( self, *, @@ -19,18 +30,54 @@ def __init__( top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + custom_query: Optional[Dict[str, Any]] = None, + raise_on_failure: bool = True, ): """ - Create the OpenSearchBM25Retriever component. - - :param document_store: An instance of OpenSearchDocumentStore. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - :param fuzziness: Fuzziness parameter for full-text queries. Defaults to "AUTO". - :param top_k: Maximum number of Documents to return, defaults to 10 - :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. - This is useful when comparing documents across different indexes. Defaults to False. - :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. - This is useful when searching for short text where even one term can make a difference. Defaults to False. + Creates the OpenSearchBM25Retriever component. + + :param document_store: An instance of OpenSearchDocumentStore to use with the Retriever. 
+ :param filters: Filters to narrow down the search for documents in the Document Store. + :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. + For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param top_k: Maximum number of documents to return. + :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. + This is useful when comparing documents across different indexes. + :param all_terms_must_match: If `True`, all terms in the query string must be present in the + retrieved documents. This is useful when searching for short text where even one term + can make a difference. + :param filter_policy: Policy to determine how filters are applied. Possible options: + - `replace`: Runtime filters replace initialization filters. Use this policy to change the filtering scope + for specific queries. + - `merge`: Runtime filters are merged with initialization filters. + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder. + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + An example `run()` method for this `custom_query`: + + ```python + retriever.run(query="Why did the revenue increase?", + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :param raise_on_failure: + Whether to raise an exception if the API call fails. Otherwise log a warning and return an empty list. + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """ @@ -44,6 +91,11 @@ def __init__( self._top_k = top_k self._scale_score = scale_score self._all_terms_must_match = all_terms_must_match + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._custom_query = custom_query + self._raise_on_failure = raise_on_failure def to_dict(self) -> Dict[str, Any]: """ @@ -59,6 +111,9 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, scale_score=self._scale_score, document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + custom_query=self._custom_query, + raise_on_failure=self._raise_on_failure, ) @classmethod @@ -75,6 +130,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -86,23 +146,54 @@ def run( top_k: Optional[int] = None, fuzziness: Optional[str] = None, scale_score: Optional[bool] = None, + custom_query: Optional[Dict[str, Any]] = None, ): """ Retrieve documents using BM25 retrieval. - :param query: The query string - :param filters: Optional filters to narrow down the search space. - :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. 
- :param top_k: Maximum number of Documents to return. - :param fuzziness: Fuzziness parameter for full-text queries. - :param scale_score: Whether to scale the score of retrieved documents between 0 and 1. + :param query: The query string. + :param filters: Filters applied to the retrieved documents. The way runtime filters are applied depends on + the `filter_policy` specified at Retriever's initialization. + :param all_terms_must_match: If `True`, all terms in the query string must be present in the + retrieved documents. + :param top_k: Maximum number of documents to return. + :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. + For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. This is useful when comparing documents across different indexes. + :param custom_query: A custom OpenSearch query. It must include a `$query` and may optionally + include a `$filters` placeholder. + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + **For this custom_query, a sample `run()` could be:** + + ```python + retriever.run(query="Why did the revenue increase?", + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` :returns: A dictionary containing the retrieved documents with the following structure: - documents: List of retrieved Documents. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + if filters is None: filters = self._filters if all_terms_must_match is None: @@ -113,13 +204,29 @@ def run( fuzziness = self._fuzziness if scale_score is None: scale_score = self._scale_score + if custom_query is None: + custom_query = self._custom_query + + docs: List[Document] = [] + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=filters, + fuzziness=fuzziness, + top_k=top_k, + scale_score=scale_score, + all_terms_must_match=all_terms_must_match, + custom_query=custom_query, + ) + except Exception as e: + if self._raise_on_failure: + raise e + else: + logger.warning( + "An error during BM25 retrieval occurred and will be ignored by returning empty results: %s", + str(e), + exc_info=True, + ) - docs = self._document_store._bm25_retrieval( - query=query, - filters=filters, - fuzziness=fuzziness, - top_k=top_k, - scale_score=scale_score, - all_terms_must_match=all_terms_must_match, - ) return {"documents": docs} diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 50b30d7f1..cdf905b97 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -1,19 +1,24 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +import logging +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses 
import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +logger = logging.getLogger(__name__) + @component class OpenSearchEmbeddingRetriever: """ - Uses a vector similarity metric to retrieve documents from the OpenSearchDocumentStore. + Retrieves documents from the OpenSearchDocumentStore using a vector similarity metric. - Needs to be connected to the OpenSearchDocumentStore to run. + Must be connected to the OpenSearchDocumentStore to run. """ def __init__( @@ -22,14 +27,56 @@ def __init__( document_store: OpenSearchDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + custom_query: Optional[Dict[str, Any]] = None, + raise_on_failure: bool = True, ): """ Create the OpenSearchEmbeddingRetriever component. - :param document_store: An instance of OpenSearchDocumentStore. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param document_store: An instance of OpenSearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. Possible options: + - `merge`: Runtime filters are merged with initialization filters. + - `replace`: Runtime filters replace initialization filters. Use this policy to change the filtering scope. + :param custom_query: The custom OpenSearch query containing a mandatory `$query_embedding` and + an optional `$filters` placeholder. + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + For this `custom_query`, an example `run()` could be: + + ```python + retriever.run(query_embedding=embedding, + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :param raise_on_failure: + If `True`, raises an exception if the API call fails. + If `False`, logs a warning and returns an empty list. + :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. 
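
The `filter_policy` semantics are shared between the BM25 and embedding retrievers: both call `apply_filter_policy` from Haystack to combine the filters given at init time with the ones passed to `run()`. A small sketch of the difference between the two policies (the filter values are made up; roughly, `REPLACE` keeps only the runtime filters when they are provided, while `MERGE` combines both sets):

```python
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

init_filters = {"field": "meta.year", "operator": "==", "value": "2019"}
runtime_filters = {"field": "meta.quarter", "operator": "==", "value": "Q1"}

# REPLACE: runtime filters take over completely when they are provided.
replaced = apply_filter_policy(FilterPolicy.REPLACE, init_filters, runtime_filters)

# MERGE: init-time and runtime filters are combined into a single filter.
merged = apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters)

print(replaced)
print(merged)
```
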
""" if not isinstance(document_store, OpenSearchDocumentStore): @@ -39,6 +86,11 @@ def __init__( self._document_store = document_store self._filters = filters or {} self._top_k = top_k + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._custom_query = custom_query + self._raise_on_failure = raise_on_failure def to_dict(self) -> Dict[str, Any]: """ @@ -52,6 +104,9 @@ def to_dict(self) -> Dict[str, Any]: filters=self._filters, top_k=self._top_k, document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + custom_query=self._custom_query, + raise_on_failure=self._raise_on_failure, ) @classmethod @@ -68,28 +123,92 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + custom_query: Optional[Dict[str, Any]] = None, + ): """ Retrieve documents using a vector similarity metric. :param query_embedding: Embedding of the query. - :param filters: Optional filters to narrow down the search space. - :param top_k: Maximum number of Documents to return. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever + returns `top_k` matching documents. + The way runtime filters are applied depends on the `filter_policy` selected when initializing the Retriever. + :param top_k: Maximum number of documents to return. + :param custom_query: A custom OpenSearch query containing a mandatory `$query_embedding` and an + optional `$filters` placeholder. + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + + For this `custom_query`, an example `run()` could be: + + ```python + retriever.run(query_embedding=embedding, + filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + ``` + :returns: Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. 
""" + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + top_k = top_k or self._top_k if filters is None: filters = self._filters if top_k is None: top_k = self._top_k + if custom_query is None: + custom_query = self._custom_query + + docs: List[Document] = [] + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + custom_query=custom_query, + ) + except Exception as e: + if self._raise_on_failure: + raise e + else: + logger.warning( + "An error during embedding retrieval occurred and will be ignored by returning empty results: %s", + str(e), + exc_info=True, + ) - docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=filters, - top_k=top_k, - ) return {"documents": docs} diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py new file mode 100644 index 000000000..8249c16ca --- /dev/null +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass, field, fields +from typing import Any, Dict, Optional + +from haystack import default_from_dict, default_to_dict +from haystack.document_stores.errors import DocumentStoreError +from haystack.lazy_imports import LazyImport +from haystack.utils.auth import Secret, deserialize_secrets_inplace +from opensearchpy import Urllib3AWSV4SignerAuth + +with LazyImport("Run 'pip install \"boto3\"' to install boto3.") as boto3_import: + import boto3 + from botocore.exceptions import BotoCoreError + + +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + + +class AWSConfigurationError(DocumentStoreError): + """Exception raised when AWS is not configured correctly""" + + +def _get_aws_session( + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, +): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :returns: The created AWS session. + """ + boto3_import.check() + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, + ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + raise AWSConfigurationError(msg) from e + + +@dataclass() +class AWSAuth: + """ + Auth credentials for AWS OpenSearch services. 
+ + This class works as a thin wrapper around the `Urllib3AWSV4SignerAuth` class from the `opensearch-py` library. + It facilitates the creation of the `Urllib3AWSV4SignerAuth` by making use of Haystack secrets and taking care of + the necessary `Urllib3AWSV4SignerAuth` creation steps including boto3 Sessions and boto3 credentials. + """ + + aws_access_key_id: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False) + ) + aws_secret_access_key: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_SECRET_ACCESS_KEY", strict=False) + ) + aws_session_token: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_SESSION_TOKEN", strict=False) + ) + aws_region_name: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_DEFAULT_REGION", strict=False) + ) + aws_profile_name: Optional[Secret] = field(default_factory=lambda: Secret.from_env_var("AWS_PROFILE", strict=False)) + aws_service: str = field(default="es") + + def __post_init__(self) -> None: + """ + Initializes the AWSAuth object. + """ + self._urllib3_aws_v4_signer_auth = self._get_urllib3_aws_v4_signer_auth() + + def to_dict(self) -> Dict[str, Any]: + """ + Converts the object to a dictionary representation for serialization. + """ + _fields = {} + for _field in fields(self): + field_value = getattr(self, _field.name) + if _field.type == Optional[Secret]: + _fields[_field.name] = field_value.to_dict() if field_value is not None else None + else: + _fields[_field.name] = field_value + + return default_to_dict(self, **_fields) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["AWSAuth"]: + """ + Converts a dictionary representation to an AWSAuth object. + """ + init_parameters = data.get("init_parameters", {}) + deserialize_secrets_inplace( + init_parameters, + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) + return default_from_dict(cls, data) + + def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]: + """ + Signs the request and returns headers. + + This method is executed by Urllib3 when making a request to the OpenSearch service. + + :param method: HTTP method + :param url: URL + :param body: Body + """ + return self._urllib3_aws_v4_signer_auth(method, url, body) + + def _get_urllib3_aws_v4_signer_auth(self) -> Urllib3AWSV4SignerAuth: + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + + try: + region_name = resolve_secret(self.aws_region_name) + session = _get_aws_session( + aws_access_key_id=resolve_secret(self.aws_access_key_id), + aws_secret_access_key=resolve_secret(self.aws_secret_access_key), + aws_session_token=resolve_secret(self.aws_session_token), + aws_region_name=region_name, + aws_profile_name=resolve_secret(self.aws_profile_name), + ) + credentials = session.get_credentials() + return Urllib3AWSV4SignerAuth(credentials, region_name, self.aws_service) + except Exception as exception: + msg = ( + "Could not connect to AWS OpenSearch. Make sure the AWS environment is configured correctly. 
" + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AWSConfigurationError(msg) from exception diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 1da495228..465897608 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -10,6 +10,7 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils.filters import convert +from haystack_integrations.document_stores.opensearch.auth import AWSAuth from haystack_integrations.document_stores.opensearch.filters import normalize_filters from opensearchpy import OpenSearch from opensearchpy.helpers import bulk @@ -27,6 +28,9 @@ # all be mapped to scores ~1. BM25_SCALING_FACTOR = 8 +DEFAULT_SETTINGS = {"index.knn": True} +DEFAULT_MAX_CHUNK_BYTES = 100 * 1024 * 1024 + class OpenSearchDocumentStore: def __init__( @@ -34,58 +38,140 @@ def __init__( *, hosts: Optional[Hosts] = None, index: str = "default", + max_chunk_bytes: int = DEFAULT_MAX_CHUNK_BYTES, + embedding_dim: int = 768, + return_embedding: bool = False, + method: Optional[Dict[str, Any]] = None, + mappings: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = DEFAULT_SETTINGS, + create_index: bool = True, + http_auth: Any = None, + use_ssl: Optional[bool] = None, + verify_certs: Optional[bool] = None, + timeout: Optional[int] = None, **kwargs, ): """ Creates a new OpenSearchDocumentStore instance. - For more information on connection parameters, see the [official OpenSearch documentation](https://opensearch.org/docs/latest/clients/python-low-level/#connecting-to-opensearch) + The ``embeddings_dim``, ``method``, ``mappings``, and ``settings`` arguments are only used if the index does not + exists and needs to be created. If the index already exists, its current configurations will be used. - For the full list of supported kwargs, see the [official OpenSearch reference](https://opensearch-project.github.io/opensearch-py/api-ref/clients/opensearch_client.html) + For more information on connection parameters, see the [official OpenSearch documentation](https://opensearch.org/docs/latest/clients/python-low-level/#connecting-to-opensearch) :param hosts: List of hosts running the OpenSearch client. Defaults to None :param index: Name of index in OpenSearch, if it doesn't exist it will be created. Defaults to "default" - :param **kwargs: Optional arguments that ``OpenSearch`` takes. + :param max_chunk_bytes: Maximum size of the requests in bytes. Defaults to 100MB + :param embedding_dim: Dimension of the embeddings. Defaults to 768 + :param return_embedding: + Whether to return the embedding of the retrieved Documents. + :param method: The method definition of the underlying configuration of the approximate k-NN algorithm. Please + see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#method-definitions) + for more information. Defaults to None + :param mappings: The mapping of how the documents are stored and indexed. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/field-types/) + for more information. 
If None, it uses the embedding_dim and method arguments to create default mappings. + Defaults to None + :param settings: The settings of the index to be created. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#index-settings) + for more information. Defaults to {"index.knn": True} + :param create_index: Whether to create the index if it doesn't exist. Defaults to True + :param http_auth: http_auth param passed to the underying connection class. + For basic authentication with default connection class `Urllib3HttpConnection` this can be + - a tuple of (username, password) + - a list of [username, password] + - a string of "username:password" + For AWS authentication with `Urllib3HttpConnection` pass an instance of `AWSAuth`. + Defaults to None + :param use_ssl: Whether to use SSL. Defaults to None + :param verify_certs: Whether to verify certificates. Defaults to None + :param timeout: Timeout in seconds. Defaults to None + :param **kwargs: Optional arguments that ``OpenSearch`` takes. For the full list of supported kwargs, + see the [official OpenSearch reference](https://opensearch-project.github.io/opensearch-py/api-ref/clients/opensearch_client.html) """ + self._client = None self._hosts = hosts - self._client = OpenSearch(hosts, **kwargs) self._index = index + self._max_chunk_bytes = max_chunk_bytes + self._embedding_dim = embedding_dim + self._return_embedding = return_embedding + self._method = method + self._mappings = mappings or self._get_default_mappings() + self._settings = settings + self._create_index = create_index + self._http_auth = http_auth + self._use_ssl = use_ssl + self._verify_certs = verify_certs + self._timeout = timeout self._kwargs = kwargs - # Check client connection, this will raise if not connected - self._client.info() - - # configure mapping for the embedding field - embedding_dim = kwargs.get("embedding_dim", 768) - method = kwargs.get("method", None) - - mappings: Dict[str, Any] = { + def _get_default_mappings(self) -> Dict[str, Any]: + default_mappings: Dict[str, Any] = { "properties": { - "embedding": {"type": "knn_vector", "index": True, "dimension": embedding_dim}, + "embedding": {"type": "knn_vector", "index": True, "dimension": self._embedding_dim}, "content": {"type": "text"}, }, "dynamic_templates": [ { "strings": { "match_mapping_type": "string", - "mapping": { - "type": "keyword", - }, + "mapping": {"type": "keyword"}, } } ], } - if method: - mappings["properties"]["embedding"]["method"] = method + if self._method: + default_mappings["properties"]["embedding"]["method"] = self._method + return default_mappings + + @property + def client(self) -> OpenSearch: + if not self._client: + self._client = OpenSearch( + hosts=self._hosts, + http_auth=self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, + **self._kwargs, + ) + + if self._client.indices.exists(index=self._index): # type:ignore + logger.debug( + "The index '%s' already exists. 
The `embedding_dim`, `method`, `mappings`, and " + "`settings` values will be ignored.", + self._index, + ) + elif self._create_index: + # Create the index if it doesn't exist + body = {"mappings": self._mappings, "settings": self._settings} + self._client.indices.create(index=self._index, body=body) # type:ignore + return self._client + + def create_index( + self, + index: Optional[str] = None, + mappings: Optional[Dict[str, Any]] = None, + settings: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Creates an index in OpenSearch. - mappings = kwargs.get("mappings", mappings) - settings = kwargs.get("settings", {"index.knn": True}) + Note that this method ignores the `create_index` argument from the constructor. - body = {"mappings": mappings, "settings": settings} + :param index: Name of the index to create. If None, the index name from the constructor is used. + :param mappings: The mapping of how the documents are stored and indexed. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/field-types/) + for more information. If None, the mappings from the constructor are used. + :param settings: The settings of the index to be created. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#index-settings) + for more information. If None, the settings from the constructor are used. + """ + if not index: + index = self._index + if not mappings: + mappings = self._mappings + if not settings: + settings = self._settings - # Create the index if it doesn't exist - if not self._client.indices.exists(index=index): - self._client.indices.create(index=index, body=body) + if not self.client.indices.exists(index=index): + self.client.indices.create(index=index, body={"mappings": mappings, "settings": settings}) def to_dict(self) -> Dict[str, Any]: # This is not the best solution to serialise this class but is the fastest to implement. @@ -101,6 +187,17 @@ def to_dict(self) -> Dict[str, Any]: self, hosts=self._hosts, index=self._index, + max_chunk_bytes=self._max_chunk_bytes, + embedding_dim=self._embedding_dim, + method=self._method, + mappings=self._mappings, + settings=self._settings, + create_index=self._create_index, + return_embedding=self._return_embedding, + http_auth=self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, **self._kwargs, ) @@ -115,19 +212,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchDocumentStore": :returns: Deserialized component. """ + if http_auth := data.get("init_parameters", {}).get("http_auth"): + if isinstance(http_auth, dict): + data["init_parameters"]["http_auth"] = AWSAuth.from_dict(http_auth) + return default_from_dict(cls, data) def count_documents(self) -> int: """ Returns how many documents are present in the document store. """ - return self._client.count(index=self._index)["count"] + return self.client.count(index=self._index)["count"] def _search_documents(self, **kwargs) -> List[Document]: """ Calls the OpenSearch client's search method and handles pagination. 
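
Taken together with the new `http_auth` handling, wiring the store up against an AWS-managed OpenSearch domain might look roughly like this (the endpoint is a placeholder; credentials are picked up from the AWS environment by `AWSAuth`):

```python
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
from haystack_integrations.document_stores.opensearch.auth import AWSAuth

document_store = OpenSearchDocumentStore(
    hosts="https://my-domain.eu-central-1.es.amazonaws.com:443",  # placeholder endpoint
    index="default",
    http_auth=AWSAuth(aws_service="es"),
    use_ssl=True,
    verify_certs=True,
    timeout=30,
)

# The client is created lazily: the connection is opened (and the index created,
# if create_index=True and it does not exist yet) on first use.
print(document_store.count_documents())
```
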
""" - res = self._client.search( + res = self.client.search( index=self._index, body=kwargs, ) @@ -162,7 +263,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D action = "index" if policy == DuplicatePolicy.OVERWRITE else "create" documents_written, errors = bulk( - client=self._client, + client=self.client, actions=( { "_op_type": action, @@ -174,6 +275,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D refresh="wait_for", index=self._index, raise_on_error=False, + max_chunk_bytes=self._max_chunk_bytes, ) if errors: @@ -225,11 +327,12 @@ def delete_documents(self, document_ids: List[str]) -> None: """ bulk( - client=self._client, + client=self.client, actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids), refresh="wait_for", index=self._index, raise_on_error=False, + max_chunk_bytes=self._max_chunk_bytes, ) def _bm25_retrieval( @@ -241,6 +344,7 @@ def _bm25_retrieval( top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, + custom_query: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ OpenSearch by defaults uses BM25 search algorithm. @@ -251,8 +355,6 @@ def _bm25_retrieval( `OpenSearchDocumentStore` nor called directly. `OpenSearchBM25Retriever` uses this method directly and is the public interface for it. - `query` must be a non empty string, otherwise a `ValueError` will be raised. - :param query: String to search in saved Documents' text. :param filters: Optional filters to narrow down the search space. :param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO". see the official documentation @@ -260,41 +362,71 @@ def _bm25_retrieval( :param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False :param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False - :raises ValueError: If `query` is an empty string + :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder + + **An example custom_query:** + + ```python + { + "query": { + "bool": { + "should": [{"multi_match": { + "query": "$query", // mandatory query placeholder + "type": "most_fields", + "fields": ["content", "title"]}}], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + :returns: List of Document that match `query` """ + if filters and "operator" not in filters and "conditions" not in filters: + filters = convert(filters) if not query: - msg = "query must be a non empty string" - raise ValueError(msg) + body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}} + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + if isinstance(custom_query, dict): + body = self._render_custom_query(custom_query, {"$query": query, "$filters": normalize_filters(filters)}) - operator = "AND" if all_terms_must_match else "OR" - body: Dict[str, Any] = { - "size": top_k, - "query": { - "bool": { - "must": [ - { - "multi_match": { - "query": query, - "fuzziness": fuzziness, - "type": "most_fields", - "operator": operator, + else: + operator = "AND" if all_terms_must_match else "OR" + body = { + "query": { + "bool": { + "must": [ + { + "multi_match": { + "query": query, + "fuzziness": fuzziness, + "type": "most_fields", + "operator": operator, + } } - } - ] - } - }, - } + ] + } + }, + } - if filters: - body["query"]["bool"]["filter"] = 
normalize_filters(filters) + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k + + # For some applications not returning the embedding can save a lot of bandwidth + # if you don't need this data not retrieving it can be a good idea + if not self._return_embedding: + body["_source"] = {"excludes": ["embedding"]} documents = self._search_documents(**body) if scale_score: for doc in documents: - doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR)))) + doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR)))) # type:ignore return documents @@ -304,6 +436,7 @@ def _embedding_retrieval( *, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + custom_query: Optional[Dict[str, Any]] = None, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -317,34 +450,88 @@ def _embedding_retrieval( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 + :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder + + **An example custom_query:** + ```python + { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": "$query_embedding", // mandatory query placeholder + "k": 10000, + } + } + } + ], + "filter": "$filters" // optional filter placeholder + } + } + } + ``` + :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` """ + if filters and "operator" not in filters and "conditions" not in filters: + filters = convert(filters) if not query_embedding: msg = "query_embedding must be a non-empty list of floats" raise ValueError(msg) - body: Dict[str, Any] = { - "query": { - "bool": { - "must": [ - { - "knn": { - "embedding": { - "vector": query_embedding, - "k": top_k, + if isinstance(custom_query, dict): + body = self._render_custom_query( + custom_query, {"$query_embedding": query_embedding, "$filters": normalize_filters(filters)} + ) + + else: + body = { + "query": { + "bool": { + "must": [ + { + "knn": { + "embedding": { + "vector": query_embedding, + "k": top_k, + } } } - } - ], - } - }, - "size": top_k, - } + ], + } + }, + } - if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if filters: + body["query"]["bool"]["filter"] = normalize_filters(filters) + + body["size"] = top_k + + # For some applications not returning the embedding can save a lot of bandwidth + # if you don't need this data not retrieving it can be a good idea + if not self._return_embedding: + body["_source"] = {"excludes": ["embedding"]} docs = self._search_documents(**body) return docs + + def _render_custom_query(self, custom_query: Any, substitutions: Dict[str, Any]) -> Any: + """ + Recursively replaces the placeholders in the custom_query with the actual values. + + :param custom_query: The custom query to replace the placeholders in. + :param substitutions: The dictionary containing the actual values to replace the placeholders with. + :returns: The custom query with the placeholders replaced. 
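
To make the placeholder substitution concrete, here is a stand-alone copy of the same recursion applied to a small `custom_query` (for illustration only; in the document store this logic runs as the private `_render_custom_query` method shown above):

```python
from typing import Any, Dict


def render_custom_query(custom_query: Any, substitutions: Dict[str, Any]) -> Any:
    """Stand-alone mirror of the substitution logic, for illustration."""
    if isinstance(custom_query, dict):
        return {key: render_custom_query(value, substitutions) for key, value in custom_query.items()}
    if isinstance(custom_query, list):
        return [render_custom_query(entry, substitutions) for entry in custom_query]
    if isinstance(custom_query, str):
        return substitutions.get(custom_query, custom_query)
    return custom_query


query = {"query": {"bool": {"should": [{"multi_match": {"query": "$query"}}], "filter": "$filters"}}}
rendered = render_custom_query(
    query, {"$query": "Why did the revenue increase?", "$filters": {"term": {"year": "2019"}}}
)
print(rendered)
# Every "$query" / "$filters" string is replaced by the matching value; all other
# strings, dicts, and lists pass through unchanged.
```
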
+ """ + if isinstance(custom_query, dict): + return {key: self._render_custom_query(value, substitutions) for key, value in custom_query.items()} + elif isinstance(custom_query, list): + return [self._render_custom_query(entry, substitutions) for entry in custom_query] + elif isinstance(custom_query, str): + return substitutions.get(custom_query, custom_query) + + return custom_query diff --git a/integrations/opensearch/tests/test_auth.py b/integrations/opensearch/tests/test_auth.py new file mode 100644 index 000000000..25bda7d66 --- /dev/null +++ b/integrations/opensearch/tests/test_auth.py @@ -0,0 +1,113 @@ +from unittest.mock import Mock, patch + +import pytest +from haystack_integrations.document_stores.opensearch.auth import AWSAuth +from opensearchpy import Urllib3AWSV4SignerAuth + + +class TestAWSAuth: + @pytest.fixture(autouse=True) + def mock_boto3_session(self): + with patch("boto3.Session") as mock_client: + yield mock_client + + @pytest.fixture(autouse=True) + def set_aws_env_variables(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token") + monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region") + monkeypatch.setenv("AWS_PROFILE", "some_fake_profile") + + def test_init(self, mock_boto3_session): + aws_auth = AWSAuth() + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + def test_to_dict(self): + aws_auth = AWSAuth() + res = aws_auth.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + } + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id.resolve_value() == "some_fake_id" + assert aws_auth.aws_secret_access_key.resolve_value() == "some_fake_key" + assert aws_auth.aws_session_token.resolve_value() == "some_fake_token" + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "es" + assert 
isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + def test_from_dict_no_init_parameters(self): + data = {"type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth"} + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id.resolve_value() == "some_fake_id" + assert aws_auth.aws_secret_access_key.resolve_value() == "some_fake_key" + assert aws_auth.aws_session_token.resolve_value() == "some_fake_token" + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "es" + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + def test_from_dict_disable_env_variables(self): + data = { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": None, + "aws_secret_access_key": None, + "aws_session_token": None, + "aws_service": "aoss", + }, + } + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id is None + assert aws_auth.aws_secret_access_key is None + assert aws_auth.aws_session_token is None + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "aoss" + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + @patch("haystack_integrations.document_stores.opensearch.auth.AWSAuth._get_urllib3_aws_v4_signer_auth") + def test_call(self, _get_urllib3_aws_v4_signer_auth_mock): + signer_auth_mock = Mock(spec=Urllib3AWSV4SignerAuth) + _get_urllib3_aws_v4_signer_auth_mock.return_value = signer_auth_mock + aws_auth = AWSAuth() + aws_auth(method="GET", url="http://some.url", body="some body") + signer_auth_mock.assert_called_once_with("GET", "http://some.url", "some body") diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 3f84f41a9..c015d360a 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -3,9 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES def test_init_default(): @@ -15,20 +18,46 @@ def test_init_default(): assert retriever._filters == {} assert retriever._top_k == 10 assert not retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchBM25Retriever(document_store=mock_store, filter_policy="unknown") @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") def test_to_dict(_mock_opensearch_client): document_store = OpenSearchDocumentStore(hosts="some fake host") - retriever = OpenSearchBM25Retriever(document_store=document_store) + retriever = OpenSearchBM25Retriever(document_store=document_store, custom_query={"some": "custom query"}) res = retriever.to_dict() 
assert res == { "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", "init_parameters": { "document_store": { "init_parameters": { + "embedding_dim": 768, "hosts": "some fake host", "index": "default", + "mappings": { + "dynamic_templates": [ + {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}} + ], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", }, @@ -36,6 +65,9 @@ def test_to_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": False, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": True, }, } @@ -53,6 +85,9 @@ def test_from_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": True, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, }, } retriever = OpenSearchBM25Retriever.from_dict(data) @@ -61,6 +96,28 @@ def test_from_dict(_mock_opensearch_client): assert retriever._fuzziness == "AUTO" assert retriever._top_k == 10 assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._custom_query == {"some": "custom query"} + assert retriever._raise_on_failure is False + + # For backwards compatibility with older versions of the retriever without a filter policy + data = { + "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "fuzziness": "AUTO", + "top_k": 10, + "scale_score": True, + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchBM25Retriever.from_dict(data) + assert retriever._filter_policy == FilterPolicy.REPLACE def test_run(): @@ -75,6 +132,7 @@ def test_run(): top_k=10, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -91,6 +149,7 @@ def test_run_init_params(): scale_score=True, top_k=11, fuzziness="1", + custom_query={"some": "custom query"}, ) res = retriever.run(query="some query") mock_store._bm25_retrieval.assert_called_once_with( @@ -100,6 +159,7 @@ def test_run_init_params(): top_k=11, scale_score=True, all_terms_must_match=True, + custom_query={"some": "custom query"}, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -132,7 +192,18 @@ def test_run_time_params(): top_k=9, scale_score=False, all_terms_must_match=False, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 assert res["documents"][0].content == "Test doc" + + +def test_run_ignore_errors(caplog): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._bm25_retrieval.side_effect = Exception("Some error") + retriever = OpenSearchBM25Retriever(document_store=mock_store, raise_on_failure=False) + res = retriever.run(query="some query") + assert len(res) == 1 + assert 
res["documents"] == [] + assert "Some error" in caplog.text diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 8e984953d..287c24f63 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -10,7 +10,10 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.utils.auth import Secret from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +from haystack_integrations.document_stores.opensearch.auth import AWSAuth +from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES from opensearchpy.exceptions import RequestError @@ -21,8 +24,25 @@ def test_to_dict(_mock_opensearch_client): assert res == { "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", "init_parameters": { + "embedding_dim": 768, "hosts": "some hosts", "index": "default", + "mappings": { + "dynamic_templates": [{"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}}], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, } @@ -34,11 +54,213 @@ def test_from_dict(_mock_opensearch_client): "init_parameters": { "hosts": "some hosts", "index": "default", + "max_chunk_bytes": 1000, + "embedding_dim": 1536, + "create_index": False, + "return_embedding": True, + "aws_service": "es", + "http_auth": ("admin", "admin"), + "use_ssl": True, + "verify_certs": True, + "timeout": 60, }, } document_store = OpenSearchDocumentStore.from_dict(data) assert document_store._hosts == "some hosts" assert document_store._index == "default" + assert document_store._max_chunk_bytes == 1000 + assert document_store._embedding_dim == 1536 + assert document_store._method is None + assert document_store._mappings == { + "properties": { + "embedding": {"type": "knn_vector", "index": True, "dimension": 1536}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "match_mapping_type": "string", + "mapping": {"type": "keyword"}, + } + } + ], + } + assert document_store._settings == {"index.knn": True} + assert document_store._return_embedding is True + assert document_store._create_index is False + assert document_store._http_auth == ("admin", "admin") + assert document_store._use_ssl is True + assert document_store._verify_certs is True + assert document_store._timeout == 60 + + +@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") +def test_init_is_lazy(_mock_opensearch_client): + OpenSearchDocumentStore(hosts="testhost") + _mock_opensearch_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") +def test_get_default_mappings(_mock_opensearch_client): + store = OpenSearchDocumentStore(hosts="testhost", embedding_dim=1536, method={"name": "hnsw"}) + assert store._mappings["properties"]["embedding"] == { + "type": "knn_vector", + "index": True, + "dimension": 1536, + "method": {"name": 
"hnsw"}, + } + + +class TestAuth: + @pytest.fixture(autouse=True) + def mock_boto3_session(self): + with patch("boto3.Session") as mock_client: + yield mock_client + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="testhost", http_auth=("user", "pw")) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ("user", "pw") + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_without_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] is None + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_aws_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore( + hosts="testhost", + http_auth=AWSAuth(aws_region_name=Secret.from_token("dummy-region")), + use_ssl=True, + verify_certs=True, + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": ["user", "pw"], + "use_ssl": True, + "verify_certs": True, + }, + } + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pw"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": {}, + }, + "use_ssl": True, + "verify_certs": True, + }, + } + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=("user", "pw")) + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + 
"dynamic_templates": [ + {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}} + ], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": ("user", "pw"), + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=AWSAuth()) + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}} + ], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + }, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } @pytest.mark.integration @@ -67,7 +289,31 @@ def document_store(self, request): method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, ) yield store - store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + + @pytest.fixture + def document_store_readonly(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
+ """ + hosts = ["https://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=768, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + create_index=False, + ) + store.client.cluster.put_settings(body={"transient": {"action.auto_create_index": False}}) + yield store + store.client.cluster.put_settings(body={"transient": {"action.auto_create_index": True}}) + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) @pytest.fixture def document_store_embedding_dim_4(self, request): @@ -88,7 +334,7 @@ def document_store_embedding_dim_4(self, request): method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, ) yield store - store._client.indices.delete(index=index, params={"ignore": [400, 404]}) + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ @@ -121,6 +367,15 @@ def test_write_documents(self, document_store: OpenSearchDocumentStore): with pytest.raises(DuplicateDocumentError): document_store.write_documents(docs, DuplicatePolicy.FAIL) + def test_write_documents_readonly(self, document_store_readonly: OpenSearchDocumentStore): + docs = [Document(id="1")] + with pytest.raises(DocumentStoreError, match="index_not_found_exception"): + document_store_readonly.write_documents(docs) + + def test_create_index(self, document_store_readonly: OpenSearchDocumentStore): + document_store_readonly.create_index() + assert document_store_readonly.client.indices.exists(index=document_store_readonly._index) + def test_bm25_retrieval(self, document_store: OpenSearchDocumentStore): document_store.write_documents( [ @@ -248,6 +503,227 @@ def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentS assert "functional" in res[1].content assert "functional" in res[2].content + def test_bm25_retrieval_with_filters(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", 
+ meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + res = document_store._bm25_retrieval( + "programming", + top_k=10, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 5 + retrieved_ids = sorted([doc.id for doc in res]) + assert retrieved_ids == ["1", "2", "3", "4", "5"] + + def test_bm25_retrieval_with_legacy_filters(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": "object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + res = document_store._bm25_retrieval( + "programming", + top_k=10, + filters={"language_type": "functional"}, + ) + assert len(res) == 5 + retrieved_ids = sorted([doc.id for doc in res]) + assert retrieved_ids == ["1", "2", "3", "4", "5"] + + def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore): + document_store.write_documents( + [ + Document( + content="Haskell is a functional programming language", + meta={"likes": 100000, "language_type": "functional"}, + id="1", + ), + Document( + content="Lisp is a functional programming language", + meta={"likes": 10000, "language_type": "functional"}, + id="2", + ), + Document( + content="Exilir is a functional programming language", + meta={"likes": 1000, "language_type": "functional"}, + id="3", + ), + Document( + content="F# is a functional programming language", + meta={"likes": 100, "language_type": "functional"}, + id="4", + ), + Document( + content="C# is a functional programming language", + meta={"likes": 10, "language_type": "functional"}, + id="5", + ), + Document( + content="C++ is an object oriented programming language", + meta={"likes": 100000, "language_type": "object_oriented"}, + id="6", + ), + Document( + content="Dart is an object oriented programming language", + meta={"likes": 10000, "language_type": 
"object_oriented"}, + id="7", + ), + Document( + content="Go is an object oriented programming language", + meta={"likes": 1000, "language_type": "object_oriented"}, + id="8", + ), + Document( + content="Python is a object oriented programming language", + meta={"likes": 100, "language_type": "object_oriented"}, + id="9", + ), + Document( + content="Ruby is a object oriented programming language", + meta={"likes": 10, "language_type": "object_oriented"}, + id="10", + ), + Document( + content="PHP is a object oriented programming language", + meta={"likes": 1, "language_type": "object_oriented"}, + id="11", + ), + ] + ) + + custom_query = { + "query": { + "function_score": { + "query": {"bool": {"must": {"match": {"content": "$query"}}, "filter": "$filters"}}, + "field_value_factor": {"field": "likes", "factor": 0.1, "modifier": "log1p", "missing": 0}, + } + } + } + + res = document_store._bm25_retrieval( + "functional", + top_k=3, + custom_query=custom_query, + filters={"field": "language_type", "operator": "==", "value": "functional"}, + ) + assert len(res) == 3 + assert "1" == res[0].id + assert "2" == res[1].id + assert "3" == res[2].id + def test_embedding_retrieval(self, document_store_embedding_dim_4: OpenSearchDocumentStore): docs = [ Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), @@ -283,6 +759,27 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_with_legacy_filters(self, document_store_embedding_dim_4: OpenSearchDocumentStore): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + filters = {"meta_field": "custom_value"} + # we set top_k=3, to make the test pass as we are not sure whether efficient filtering is supported for nmslib + # TODO: remove top_k=3, when efficient filtering is supported for nmslib + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3, filters=filters + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. 
@@ -299,6 +796,31 @@ def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: Op ) assert len(results) == 11 + def test_embedding_retrieval_with_custom_query(self, document_store_embedding_dim_4: OpenSearchDocumentStore): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4.write_documents(docs) + + custom_query = { + "query": { + "bool": {"must": [{"knn": {"embedding": {"vector": "$query_embedding", "k": 3}}}], "filter": "$filters"} + } + } + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters=filters, custom_query=custom_query + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_query_documents_different_embedding_sizes( self, document_store_embedding_dim_4: OpenSearchDocumentStore ): @@ -333,3 +855,60 @@ def test_write_documents_with_badly_formatted_bulk_errors(self, mock_bulk, docum with pytest.raises(DocumentStoreError) as e: document_store.write_documents([Document(content="Hello world")]) e.match(f"{error}") + + @patch("haystack_integrations.document_stores.opensearch.document_store.bulk") + def test_write_documents_max_chunk_bytes(self, mock_bulk, document_store): + mock_bulk.return_value = (1, []) + document_store.write_documents([Document(content="Hello world")]) + + assert mock_bulk.call_args.kwargs["max_chunk_bytes"] == DEFAULT_MAX_CHUNK_BYTES + + @pytest.fixture + def document_store_no_embbding_returned(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
+ """ + hosts = ["https://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + return_embedding=False, + method={"space_type": "cosinesimil", "engine": "nmslib", "name": "hnsw"}, + ) + yield store + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + + def test_embedding_retrieval_but_dont_return_embeddings_for_embedding_retrieval( + self, document_store_no_embbding_returned: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]), + ] + document_store_no_embbding_returned.write_documents(docs) + results = document_store_no_embbding_returned._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={} + ) + assert len(results) == 2 + assert results[0].embedding is None + + def test_embedding_retrieval_but_dont_return_embeddings_for_bm25_retrieval( + self, document_store_no_embbding_returned: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]), + ] + document_store_no_embbding_returned.write_documents(docs) + results = document_store_no_embbding_returned._bm25_retrieval("document", top_k=2) + assert len(results) == 2 + assert results[0].embedding is None diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 0190ca208..e52a099c8 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -3,9 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.opensearch import OpenSearchEmbeddingRetriever from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES def test_init_default(): @@ -14,12 +17,19 @@ def test_init_default(): assert retriever._document_store == mock_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") def test_to_dict(_mock_opensearch_client): document_store = OpenSearchDocumentStore(hosts="some fake host") - retriever = OpenSearchEmbeddingRetriever(document_store=document_store) + retriever = OpenSearchEmbeddingRetriever(document_store=document_store, custom_query={"some": "custom query"}) res = retriever.to_dict() type_s = 
"haystack_integrations.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever" assert res == { @@ -27,13 +37,50 @@ def test_to_dict(_mock_opensearch_client): "init_parameters": { "document_store": { "init_parameters": { + "embedding_dim": 768, "hosts": "some fake host", "index": "default", + "mappings": { + "dynamic_templates": [ + { + "strings": { + "mapping": { + "type": "keyword", + }, + "match_mapping_type": "string", + }, + }, + ], + "properties": { + "content": { + "type": "text", + }, + "embedding": { + "dimension": 768, + "index": True, + "type": "knn_vector", + }, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": { + "index.knn": True, + }, + "return_embedding": False, + "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", }, "filters": {}, "top_k": 10, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": True, }, } @@ -50,12 +97,35 @@ def test_from_dict(_mock_opensearch_client): }, "filters": {}, "top_k": 10, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, }, } retriever = OpenSearchEmbeddingRetriever.from_dict(data) assert retriever._document_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._custom_query == {"some": "custom query"} + assert retriever._raise_on_failure is False + assert retriever._filter_policy == FilterPolicy.REPLACE + + # For backwards compatibility with older versions of the retriever without a filter policy + data = { + "type": type_s, + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchEmbeddingRetriever.from_dict(data) + assert retriever._filter_policy == FilterPolicy.REPLACE def test_run(): @@ -67,6 +137,7 @@ def test_run(): query_embedding=[0.5, 0.7], filters={}, top_k=10, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -77,12 +148,15 @@ def test_run(): def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] - retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) + retriever = OpenSearchEmbeddingRetriever( + document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "init"}, top_k=11, + custom_query="custom_query", ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -99,8 +173,19 @@ def test_run_time_params(): query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, + custom_query=None, ) assert len(res) == 1 assert len(res["documents"]) == 1 assert res["documents"][0].content == "Test doc" assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_ignore_errors(caplog): + mock_store = Mock(spec=OpenSearchDocumentStore) + mock_store._embedding_retrieval.side_effect = Exception("Some error") + retriever 
= OpenSearchEmbeddingRetriever(document_store=mock_store, raise_on_failure=False) + res = retriever.run(query_embedding=[0.5, 0.7]) + assert len(res) == 1 + assert res["documents"] == [] + assert "Some error" in caplog.text diff --git a/integrations/optimum/CHANGELOG.md b/integrations/optimum/CHANGELOG.md new file mode 100644 index 000000000..6699bef7a --- /dev/null +++ b/integrations/optimum/CHANGELOG.md @@ -0,0 +1,37 @@ +# Changelog + +## [integrations/optimum-v0.1.1] - 2024-07-04 + +### 🐛 Bug Fixes + +- Fix docs build (#633) + + +- Fix typo in the `ORTModel.inputs_names` field to align with upstream (#866) + +### 📚 Documentation + +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/optimum-v0.1.0] - 2024-03-04 + +### 🚀 Features + +- Add Optimum Embedders (#379) +- [**breaking**] Add support for Optimum optimizers and quantizers (#496) +- Add example for Optimum integration, fix docs, CI (#526) + +### 🚜 Refactor + +- [**breaking**] Simplify Optimum backend impl (#477) + +### 📚 Documentation + +- Fix Optimum embedder examples (#517) + + diff --git a/integrations/optimum/pydoc/config.yml b/integrations/optimum/pydoc/config.yml index 8597b07ad..62edc9502 100644 --- a/integrations/optimum/pydoc/config.yml +++ b/integrations/optimum/pydoc/config.yml @@ -19,7 +19,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Optimum integration for Haystack category_slug: integrations-api title: Optimum diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index a0ed15f92..2e0fb26a4 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -57,14 +57,18 @@ git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9] dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "haystack-pydoc-tools", - "databind-core<4.5.0", "setuptools", # FIXME: the latest 4.5.0 causes loops in pip resolver + "databind-core<4.5.0", + "setuptools", # FIXME: the latest 4.5.0 causes loops in pip resolver ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -77,7 +81,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -161,12 +165,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py 
b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py index a6d226ecc..5a9e1cf1f 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py @@ -173,7 +173,7 @@ def _tokenize_and_generate_outputs(self, texts: List[str]) -> Tuple[Dict[str, An tokenizer_outputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( self.model.device ) - model_inputs = {k: v for k, v in tokenizer_outputs.items() if k in self.model.inputs_names} + model_inputs = {k: v for k, v in tokenizer_outputs.items() if k in self.model.input_names} model_outputs = self.model(**model_inputs) return tokenizer_outputs, model_outputs diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md new file mode 100644 index 000000000..deb6faece --- /dev/null +++ b/integrations/pgvector/CHANGELOG.md @@ -0,0 +1,51 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Add filter_policy to pgvector integration (#820) + +### 🐛 Bug Fixes + +- `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/pgvector-v0.4.0] - 2024-06-20 + +### 🚀 Features + +- Defer the database connection to when it's needed (#773) +- Add customizable index names for pgvector (#818) + +## [integrations/pgvector-v0.2.0] - 2024-05-08 + +### 🚀 Features + +- `MongoDBAtlasEmbeddingRetriever` (#427) +- Implement keyword retrieval for pgvector integration (#644) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Disable-class-def (#556) + +## [integrations/pgvector-v0.1.0] - 2024-02-14 + +### 🐛 Bug Fixes + +- Fix linting (#328) + + + + diff --git a/integrations/pgvector/examples/example.py b/integrations/pgvector/examples/embedding_retrieval.py similarity index 100% rename from integrations/pgvector/examples/example.py rename to integrations/pgvector/examples/embedding_retrieval.py diff --git a/integrations/pgvector/examples/hybrid_retrieval.py b/integrations/pgvector/examples/hybrid_retrieval.py new file mode 100644 index 000000000..cee98fe08 --- /dev/null +++ b/integrations/pgvector/examples/hybrid_retrieval.py @@ -0,0 +1,69 @@ +# Before running this example, ensure you have PostgreSQL installed with the pgvector extension. +# For a quick setup using Docker: +# docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres +# -e POSTGRES_DB=postgres ankane/pgvector + +# Install required packages for this example, including pgvector-haystack and other libraries needed +# for Markdown conversion and embeddings generation. Use the following command: +# pip install pgvector-haystack markdown-it-py mdit_plain "sentence-transformers>=2.2.0" + +# Download some Markdown files to index. 
+# git clone https://github.com/anakin87/neural-search-pills + +import glob + +from haystack import Pipeline +from haystack.components.converters import MarkdownToDocument +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.joiners import DocumentJoiner +from haystack.components.preprocessors import DocumentSplitter +from haystack.components.writers import DocumentWriter +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + +# Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database. +# e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + +# Initialize PgvectorDocumentStore +document_store = PgvectorDocumentStore( + table_name="haystack_test", + embedding_dimension=768, + vector_function="cosine_similarity", + recreate_table=True, + search_strategy="hnsw", +) + +# Create the indexing Pipeline and index some documents +file_paths = glob.glob("neural-search-pills/pills/*.md") + + +indexing = Pipeline() +indexing.add_component("converter", MarkdownToDocument()) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) +indexing.add_component("document_embedder", SentenceTransformersDocumentEmbedder()) +indexing.add_component("writer", DocumentWriter(document_store)) +indexing.connect("converter", "splitter") +indexing.connect("splitter", "document_embedder") +indexing.connect("document_embedder", "writer") + +indexing.run({"converter": {"sources": file_paths}}) + +# Create the querying Pipeline and try a query +querying = Pipeline() +querying.add_component("text_embedder", SentenceTransformersTextEmbedder()) +querying.add_component("retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3)) +querying.add_component("keyword_retriever", PgvectorKeywordRetriever(document_store=document_store, top_k=3)) +querying.add_component( + "joiner", + DocumentJoiner(join_mode="reciprocal_rank_fusion", top_k=3), +) +querying.connect("text_embedder", "retriever") +querying.connect("keyword_retriever", "joiner") +querying.connect("retriever", "joiner") + +query = "cross-encoder" +results = querying.run({"text_embedder": {"text": query}, "keyword_retriever": {"query": query}}) + +for doc in results["joiner"]["documents"]: + print(doc) + print("-" * 10) diff --git a/integrations/pgvector/pydoc/config.yml b/integrations/pgvector/pydoc/config.yml index 1be4a1662..11be3aa4a 100644 --- a/integrations/pgvector/pydoc/config.yml +++ b/integrations/pgvector/pydoc/config.yml @@ -3,6 +3,7 @@ loaders: search_path: [../src] modules: [ "haystack_integrations.components.retrievers.pgvector.embedding_retriever", + "haystack_integrations.components.retrievers.pgvector.keyword_retriever", "haystack_integrations.document_stores.pgvector.document_store", ] ignore_when_discovered: ["__init__"] @@ -15,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Pgvector integration for Haystack category_slug: integrations-api title: Pgvector diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 39e2183cb..7f31a5203 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -10,9 +10,7 @@ 
readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -25,11 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "pgvector", - "psycopg[binary]" -] +dependencies = ["haystack-ai", "pgvector", "psycopg[binary]"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -51,49 +45,30 @@ git_describe_command = 'git describe --tags --match="integrations/pgvector-v[0-9 dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "ipython", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py38"] @@ -136,9 +111,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", # ignore function-call-in-default-argument "B008", ] @@ -167,12 +148,11 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[tool.pytest.ini_options] +markers = ["integration: integration tests"] [[tool.mypy.overrides]] @@ -181,6 +161,6 @@ module = [ "haystack_integrations.*", "pgvector.*", "psycopg.*", - "pytest.*" + "pytest.*", ] ignore_missing_imports = true diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py index ec0cf0dc4..ea9fa8fe7 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py @@ -2,5 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .embedding_retriever import PgvectorEmbeddingRetriever +from .keyword_retriever import PgvectorKeywordRetriever -__all__ = ["PgvectorEmbeddingRetriever"] +__all__ = ["PgvectorEmbeddingRetriever", "PgvectorKeywordRetriever"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py index 6085545cb..22aab1a73 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS @@ -62,9 +64,10 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ - :param document_store: An instance of `PgvectorDocumentStore}. + :param document_store: An instance of `PgvectorDocumentStore`. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. :param vector_function: The similarity function to use when searching for similar embeddings. @@ -75,7 +78,7 @@ def __init__( and the most similar documents are the ones with the smallest score. 
**Important**: if the document store is using the `"hnsw"` search strategy, the vector function should match the one utilized during index creation to take advantage of the index. - + :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore` or if `vector_function` is not one of the valid options. """ @@ -91,6 +94,9 @@ def __init__( self.filters = filters or {} self.top_k = top_k self.vector_function = vector_function or document_store.vector_function + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) def to_dict(self) -> Dict[str, Any]: """ @@ -104,6 +110,7 @@ def to_dict(self) -> Dict[str, Any]: filters=self.filters, top_k=self.top_k, vector_function=self.vector_function, + filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -119,6 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": """ doc_store_params = data["init_parameters"]["document_store"] data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -133,13 +144,15 @@ def run( Retrieve documents from the `PgvectorDocumentStore`, based on their embeddings. :param query_embedding: Embedding of the query. - :param filters: Filters applied to the retrieved Documents. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: Maximum number of Documents to return. :param vector_function: The similarity function to use when searching for similar embeddings. :returns: List of Documents similar to `query_embedding`. """ - filters = filters or self.filters + filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k vector_function = vector_function or self.vector_function diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py new file mode 100644 index 000000000..636471c31 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@component +class PgvectorKeywordRetriever: + """ + Retrieve documents from the `PgvectorDocumentStore`, based on keywords. + + To rank the documents, the `ts_rank_cd` function of PostgreSQL is used. 
+ It considers how often the query terms appear in the document, how close together the terms are in the document, + and how important is the part of the document where they occur. + For more details, see + [Postgres documentation](https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-RANKING). + + Usage example: + ```python + from haystack.document_stores import DuplicatePolicy + from haystack import Document + + from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + from haystack_integrations.components.retrievers.pgvector import PgvectorKeywordRetriever + + # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database. + # e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + + document_store = PgvectorDocumentStore(language="english", recreate_table=True) + + documents = [Document(content="There are over 7,000 languages spoken around the world today."), + Document(content="Elephants have been observed to behave in a way that indicates..."), + Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")] + + document_store.write_documents(documents_with_embeddings.get("documents"), policy=DuplicatePolicy.OVERWRITE) + + retriever = PgvectorKeywordRetriever(document_store=document_store) + + result = retriever.run(query="languages") + + assert res['retriever']['documents'][0].content == "There are over 7,000 languages spoken around the world today." + """ + + def __init__( + self, + *, + document_store: PgvectorDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + :param document_store: An instance of `PgvectorDocumentStore`. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. + :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore`. + """ + if not isinstance(document_store, PgvectorDocumentStore): + msg = "document_store must be an instance of PgvectorDocumentStore" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + filter_policy=self.filter_policy.value, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + doc_store_params = data["init_parameters"]["document_store"] + data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. 
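The `filter_policy` handling in these retrievers combines init-time and runtime filters through `apply_filter_policy`. A minimal, illustrative sketch (not part of the new file) of the `REPLACE` behaviour: runtime filters win when provided, otherwise the init-time filters are used; `MERGE` combines both sets according to the rules documented in Haystack.

```python
# Illustrative sketch of how init-time and runtime filters are combined.
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

init_filters = {"field": "language_type", "operator": "==", "value": "functional"}
runtime_filters = {"field": "likes", "operator": ">", "value": 100}

# Runtime filters are provided, so with REPLACE they are used as-is.
print(apply_filter_policy(FilterPolicy.REPLACE, init_filters, runtime_filters))

# No runtime filters: REPLACE falls back to the init-time filters.
print(apply_filter_policy(FilterPolicy.REPLACE, init_filters, None))
```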
+ if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): + """ + Retrieve documents from the `PgvectorDocumentStore`, based on keywords. + + :param query: String to search in `Document`s' content. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: Maximum number of Documents to return. + + :returns: A dictionary with the following keys: + - `documents`: List of `Document`s that match the query. + """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + + top_k = top_k or self.top_k + + docs = self.document_store._keyword_retrieval( + query=query, + filters=filters, + top_k=top_k, + ) + return {"documents": docs} diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index da08a5f19..ae4878aba 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -53,6 +53,12 @@ meta = EXCLUDED.meta """ +KEYWORD_QUERY = """ +SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score +FROM {table_name}, plainto_tsquery({language}, %s) query +WHERE to_tsvector({language}, content) @@ query +""" + VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"] VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { @@ -63,8 +69,6 @@ HNSW_INDEX_CREATION_VALID_KWARGS = ["m", "ef_construction"] -HNSW_INDEX_NAME = "haystack_hnsw_index" - class PgvectorDocumentStore: """ @@ -76,13 +80,16 @@ def __init__( *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), table_name: str = "haystack_documents", + language: str = "english", embedding_dimension: int = 768, vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", recreate_table: bool = False, search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor", hnsw_recreate_index_if_exists: bool = False, hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None, + hnsw_index_name: str = "haystack_hnsw_index", hnsw_ef_search: Optional[int] = None, + keyword_index_name: str = "haystack_keyword_index", ): """ Creates a new PgvectorDocumentStore instance. @@ -92,6 +99,10 @@ def __init__( :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"` :param table_name: The name of the table to use to store Haystack documents. + :param language: The language to be used to parse query and document content in keyword retrieval. + To see the list of available languages, you can run the following SQL query in your PostgreSQL database: + `SELECT cfgname FROM pg_ts_config;`. + More information can be found in this [StackOverflow answer](https://stackoverflow.com/a/39752553). :param embedding_dimension: The dimension of the embedding. 
:param vector_function: The similarity function to use when searching for similar embeddings. `"cosine_similarity"` and `"inner_product"` are similarity functions and @@ -114,9 +125,11 @@ def __init__( :param hnsw_index_creation_kwargs: Additional keyword arguments to pass to the HNSW index creation. Only used if search_strategy is set to `"hnsw"`. You can find the list of valid arguments in the [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) + :param hnsw_index_name: Index name for the HNSW index. :param hnsw_ef_search: The `ef_search` parameter to use at query time. Only used if search_strategy is set to `"hnsw"`. You can find more information about this parameter in the - [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) + [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw). + :param keyword_index_name: Index name for the Keyword index. """ self.connection_string = connection_string @@ -130,26 +143,57 @@ def __init__( self.search_strategy = search_strategy self.hnsw_recreate_index_if_exists = hnsw_recreate_index_if_exists self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {} + self.hnsw_index_name = hnsw_index_name self.hnsw_ef_search = hnsw_ef_search + self.keyword_index_name = keyword_index_name + self.language = language + self._connection = None + self._cursor = None + self._dict_cursor = None - connection = connect(self.connection_string.resolve_value()) - connection.autocommit = True - self._connection = connection + @property + def cursor(self): + if self._cursor is None: + self._create_connection() + + return self._cursor + + @property + def dict_cursor(self): + if self._dict_cursor is None: + self._create_connection() - # we create a generic cursor and another one that returns dictionaries - self._cursor = connection.cursor() - self._dict_cursor = connection.cursor(row_factory=dict_row) + return self._dict_cursor + @property + def connection(self): + if self._connection is None: + self._create_connection() + + return self._connection + + def _create_connection(self): + conn_str = self.connection_string.resolve_value() or "" + connection = connect(conn_str) + connection.autocommit = True connection.execute("CREATE EXTENSION IF NOT EXISTS vector") - register_vector(connection) + register_vector(connection) # Note: this must be called before creating the cursors. - if recreate_table: + self._connection = connection + self._cursor = self._connection.cursor() + self._dict_cursor = self._connection.cursor(row_factory=dict_row) + + # Init schema + if self.recreate_table: self.delete_table() self._create_table_if_not_exists() + self._create_keyword_index_if_not_exists() - if search_strategy == "hnsw": + if self.search_strategy == "hnsw": self._handle_hnsw() + return self._connection + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -167,7 +211,10 @@ def to_dict(self) -> Dict[str, Any]: search_strategy=self.search_strategy, hnsw_recreate_index_if_exists=self.hnsw_recreate_index_if_exists, hnsw_index_creation_kwargs=self.hnsw_index_creation_kwargs, + hnsw_index_name=self.hnsw_index_name, hnsw_ef_search=self.hnsw_ef_search, + keyword_index_name=self.keyword_index_name, + language=self.language, ) @classmethod @@ -192,11 +239,11 @@ def _execute_sql( :param sql_query: The SQL query to execute. :param params: The parameters to pass to the SQL query. :param error_msg: The error message to use if an exception is raised. 
- :param cursor: The cursor to use to execute the SQL query. Defaults to self._cursor. + :param cursor: The cursor to use to execute the SQL query. Defaults to self.cursor. """ params = params or () - cursor = cursor or self._cursor + cursor = cursor or self.cursor sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params) @@ -204,7 +251,7 @@ def _execute_sql( try: result = cursor.execute(sql_query, params) except Error as e: - self._connection.rollback() + self.connection.rollback() detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs." raise DocumentStoreError(detailed_error_msg) from e @@ -231,6 +278,29 @@ def delete_table(self): self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + def _create_keyword_index_if_not_exists(self): + """ + Internal method to create the keyword index if not exists. + """ + index_exists = bool( + self._execute_sql( + "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", + (self.table_name, self.keyword_index_name), + "Could not check if keyword index exists", + ).fetchone() + ) + + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + ).format( + index_name=Identifier(self.keyword_index_name), + table_name=Identifier(self.table_name), + language=SQLLiteral(self.language), + ) + + if not index_exists: + self._execute_sql(sql_create_index, error_msg="Could not create keyword index on table") + def _handle_hnsw(self): """ Internal method to handle the HNSW index creation. @@ -246,7 +316,7 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, HNSW_INDEX_NAME), + (self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -259,7 +329,7 @@ def _handle_hnsw(self): ) return - sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format(index_name=Identifier(HNSW_INDEX_NAME)) + sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format(index_name=Identifier(self.hnsw_index_name)) self._execute_sql(sql_drop_index, error_msg="Could not drop HNSW index") self._create_hnsw_index() @@ -277,7 +347,7 @@ def _create_hnsw_index(self): } sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format( - index_name=Identifier(HNSW_INDEX_NAME), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) ) if actual_hnsw_index_creation_kwargs: @@ -332,7 +402,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc sql_filter, params, error_msg="Could not filter documents from PgvectorDocumentStore.", - cursor=self._dict_cursor, + cursor=self.dict_cursor, ) records = result.fetchall() @@ -369,16 +439,16 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D sql_insert += SQL(" RETURNING id") - sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert + sql_query_str = sql_insert.as_string(self.cursor) if not isinstance(sql_insert, str) else sql_insert logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents) try: - self._cursor.executemany(sql_insert, db_documents, returning=True) + 
self.cursor.executemany(sql_insert, db_documents, returning=True) except IntegrityError as ie: - self._connection.rollback() + self.connection.rollback() raise DuplicateDocumentError from ie except Error as e: - self._connection.rollback() + self.connection.rollback() error_msg = ( "Could not write documents to PgvectorDocumentStore. \n" "You can find the SQL query and the parameters in the debug logs." @@ -389,9 +459,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D # https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany written_docs = 0 while True: - if self._cursor.fetchone(): + if self.cursor.fetchone(): written_docs += 1 - if not self._cursor.nextset(): + if not self.cursor.nextset(): break return written_docs @@ -444,8 +514,8 @@ def _from_pg_to_haystack_documents(documents: List[Dict[str, Any]]) -> List[Docu # postgresql returns the embedding as a string # so we need to convert it to a list of floats - if document.get("embedding"): - haystack_dict["embedding"] = [float(el) for el in document["embedding"].strip("[]").split(",")] + if document.get("embedding") is not None: + haystack_dict["embedding"] = document["embedding"].tolist() haystack_document = Document.from_dict(haystack_dict) @@ -475,6 +545,54 @@ def delete_documents(self, document_ids: List[str]) -> None: self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") + def _keyword_retrieval( + self, + query: str, + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query using a full-text search. + + This method is not meant to be part of the public interface of + `PgvectorDocumentStore` and it should not be called directly. + `PgvectorKeywordRetriever` uses this method directly and is the public interface for it. + + :returns: List of Documents that are most similar to `query` + """ + if not query: + msg = "query must be a non-empty string" + raise ValueError(msg) + + sql_select = SQL(KEYWORD_QUERY).format( + table_name=Identifier(self.table_name), + language=SQLLiteral(self.language), + query=SQLLiteral(query), + ) + + where_params = () + sql_where_clause = SQL("") + if filters: + sql_where_clause, where_params = _convert_filters_to_where_clause_and_params( + filters=filters, operator="AND" + ) + + sql_sort = SQL(" ORDER BY score DESC LIMIT {top_k}").format(top_k=SQLLiteral(top_k)) + + sql_query = sql_select + sql_where_clause + sql_sort + + result = self._execute_sql( + sql_query, + (query, *where_params), + error_msg="Could not retrieve documents from PgvectorDocumentStore.", + cursor=self.dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs + def _embedding_retrieval( self, query_embedding: List[float], @@ -489,6 +607,7 @@ def _embedding_retrieval( This method is not meant to be part of the public interface of `PgvectorDocumentStore` and it should not be called directly. `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. 
+ :returns: List of Documents that are most similar to `query_embedding` """ @@ -545,7 +664,7 @@ def _embedding_retrieval( sql_query, params, error_msg="Could not retrieve documents from PgvectorDocumentStore.", - cursor=self._dict_cursor, + cursor=self.dict_cursor, ) records = result.fetchall() diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py index daa90f502..d3604cfb3 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from datetime import datetime from itertools import chain -from typing import Any, Dict, List +from typing import Any, Dict, List, Literal, Tuple from haystack.errors import FilterError from pandas import DataFrame @@ -22,7 +22,9 @@ NO_VALUE = "no_value" -def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]: +def _convert_filters_to_where_clause_and_params( + filters: Dict[str, Any], operator: Literal["WHERE", "AND"] = "WHERE" +) -> Tuple[SQL, Tuple]: """ Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. """ @@ -31,13 +33,13 @@ def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tupl else: query, values = _parse_logical_condition(filters) - where_clause = SQL(" WHERE ") + SQL(query) + where_clause = SQL(f" {operator} ") + SQL(query) params = tuple(value for value in values if value != NO_VALUE) return where_clause, params -def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: +def _parse_logical_condition(condition: Dict[str, Any]) -> Tuple[str, List[Any]]: if "operator" not in condition: msg = f"'operator' key missing in {condition}" raise FilterError(msg) @@ -77,7 +79,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]] return sql_query, values -def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: +def _parse_comparison_condition(condition: Dict[str, Any]) -> Tuple[str, List[Any]]: field: str = condition["field"] if "operator" not in condition: msg = f"'operator' key missing in {condition}" @@ -132,20 +134,20 @@ def _treat_meta_field(field: str, value: Any) -> str: return field -def _equal(field: str, value: Any) -> tuple[str, Any]: +def _equal(field: str, value: Any) -> Tuple[str, Any]: if value is None: # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params return f"{field} IS NULL", NO_VALUE return f"{field} = %s", value -def _not_equal(field: str, value: Any) -> tuple[str, Any]: +def _not_equal(field: str, value: Any) -> Tuple[str, Any]: # we use IS DISTINCT FROM to correctly handle NULL values # (not handled by !=) return f"{field} IS DISTINCT FROM %s", value -def _greater_than(field: str, value: Any) -> tuple[str, Any]: +def _greater_than(field: str, value: Any) -> Tuple[str, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -162,7 +164,7 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]: return f"{field} > %s", value -def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: +def _greater_than_equal(field: str, value: Any) -> Tuple[str, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -179,7 +181,7 @@ def _greater_than_equal(field: str, value: Any) -> 
tuple[str, Any]: return f"{field} >= %s", value -def _less_than(field: str, value: Any) -> tuple[str, Any]: +def _less_than(field: str, value: Any) -> Tuple[str, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -196,7 +198,7 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]: return f"{field} < %s", value -def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: +def _less_than_equal(field: str, value: Any) -> Tuple[str, Any]: if isinstance(value, str): try: datetime.fromisoformat(value) @@ -213,7 +215,7 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: return f"{field} <= %s", value -def _not_in(field: str, value: Any) -> tuple[str, List]: +def _not_in(field: str, value: Any) -> Tuple[str, List]: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" raise FilterError(msg) @@ -221,7 +223,7 @@ def _not_in(field: str, value: Any) -> tuple[str, List]: return f"{field} IS NULL OR {field} != ALL(%s)", [value] -def _in(field: str, value: Any) -> tuple[str, List]: +def _in(field: str, value: Any) -> Tuple[str, List]: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" raise FilterError(msg) diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py index 94b35a04d..b53589763 100644 --- a/integrations/pgvector/tests/conftest.py +++ b/integrations/pgvector/tests/conftest.py @@ -36,10 +36,11 @@ def patches_for_unit_tests(): ) as mock_delete, patch( "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_table_if_not_exists" ) as mock_create, patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._create_keyword_index_if_not_exists" + ) as mock_create_kw_index, patch( "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore._handle_hnsw" ) as mock_hnsw: - - yield mock_connect, mock_register, mock_delete, mock_create, mock_hnsw + yield mock_connect, mock_register, mock_delete, mock_create, mock_create_kw_index, mock_hnsw @pytest.fixture diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index bf5ccd5d4..eca8190ee 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -4,11 +4,13 @@ from unittest.mock import patch +import numpy as np import pytest from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest +from haystack.utils import Secret from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore from pandas import DataFrame @@ -51,7 +53,9 @@ def test_init(monkeypatch): search_strategy="hnsw", hnsw_recreate_index_if_exists=True, hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_index_name="my_hnsw_index", hnsw_ef_search=50, + keyword_index_name="my_keyword_index", ) assert document_store.table_name == "my_table" @@ -61,7 +65,9 @@ def test_init(monkeypatch): assert document_store.search_strategy == "hnsw" assert document_store.hnsw_recreate_index_if_exists assert document_store.hnsw_index_creation_kwargs == {"m": 32, "ef_construction": 128} + assert document_store.hnsw_index_name == 
"my_hnsw_index" assert document_store.hnsw_ef_search == 50 + assert document_store.keyword_index_name == "my_keyword_index" @pytest.mark.usefixtures("patches_for_unit_tests") @@ -76,7 +82,9 @@ def test_to_dict(monkeypatch): search_strategy="hnsw", hnsw_recreate_index_if_exists=True, hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_index_name="my_hnsw_index", hnsw_ef_search=50, + keyword_index_name="my_keyword_index", ) assert document_store.to_dict() == { @@ -89,8 +97,11 @@ def test_to_dict(monkeypatch): "recreate_table": True, "search_strategy": "hnsw", "hnsw_recreate_index_if_exists": True, + "language": "english", "hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128}, + "hnsw_index_name": "my_hnsw_index", "hnsw_ef_search": 50, + "keyword_index_name": "my_keyword_index", }, } @@ -169,7 +180,7 @@ def test_from_pg_to_haystack_documents(): "blob_meta": None, "blob_mime_type": None, "meta": {"meta_key": "meta_value"}, - "embedding": "[0.1, 0.2, 0.3]", + "embedding": np.array([0.1, 0.2, 0.3]), }, { "id": "2", @@ -179,7 +190,7 @@ def test_from_pg_to_haystack_documents(): "blob_meta": None, "blob_mime_type": None, "meta": {"meta_key": "meta_value"}, - "embedding": "[0.4, 0.5, 0.6]", + "embedding": np.array([0.4, 0.5, 0.6]), }, { "id": "3", @@ -189,16 +200,11 @@ def test_from_pg_to_haystack_documents(): "blob_meta": {"blob_meta_key": "blob_meta_value"}, "blob_mime_type": "mime_type", "meta": {"meta_key": "meta_value"}, - "embedding": "[0.7, 0.8, 0.9]", + "embedding": np.array([0.7, 0.8, 0.9]), }, ] - with patch( - "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" - ) as mock_init: - mock_init.return_value = None - ds = PgvectorDocumentStore(connection_string="test") - + ds = PgvectorDocumentStore(connection_string=Secret.from_token("test")) haystack_docs = ds._from_pg_to_haystack_documents(pg_docs) assert haystack_docs[0].id == "1" diff --git a/integrations/pgvector/tests/test_keyword_retrieval.py b/integrations/pgvector/tests/test_keyword_retrieval.py new file mode 100644 index 000000000..4a5614165 --- /dev/null +++ b/integrations/pgvector/tests/test_keyword_retrieval.py @@ -0,0 +1,50 @@ +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.mark.integration +class TestKeywordRetrieval: + def test_keyword_retrieval(self, document_store: PgvectorDocumentStore): + docs = [ + Document(content="The quick brown fox chased the dog", embedding=[0.1] * 768), + Document(content="The fox was brown", embedding=[0.1] * 768), + Document(content="The lazy dog", embedding=[0.1] * 768), + Document(content="fox fox fox", embedding=[0.1] * 768), + ] + + document_store.write_documents(docs) + + results = document_store._keyword_retrieval(query="fox", top_k=2) + + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].id == docs[-1].id + assert results[0].score > results[1].score + + def test_keyword_retrieval_with_filters(self, document_store: PgvectorDocumentStore): + docs = [ + Document( + content="The quick brown fox chased the dog", + embedding=([0.1] * 768), + meta={"meta_field": "right_value"}, + ), + Document(content="The fox was brown", embedding=([0.1] * 768), meta={"meta_field": "right_value"}), + Document(content="The lazy dog", embedding=([0.1] * 768), meta={"meta_field": "right_value"}), + Document(content="fox fox fox", embedding=([0.1] * 768), meta={"meta_field": 
"wrong_value"}), + ] + + document_store.write_documents(docs) + + filters = {"field": "meta.meta_field", "operator": "==", "value": "right_value"} + + results = document_store._keyword_retrieval(query="fox", top_k=3, filters=filters) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert doc.meta["meta_field"] == "right_value" + + def test_empty_query(self, document_store: PgvectorDocumentStore): + with pytest.raises(ValueError): + document_store._keyword_retrieval(query="") diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py deleted file mode 100644 index 61381c24e..000000000 --- a/integrations/pgvector/tests/test_retriever.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from unittest.mock import Mock - -import pytest -from haystack.dataclasses import Document -from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever -from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore - - -class TestRetriever: - def test_init_default(self, mock_store): - retriever = PgvectorEmbeddingRetriever(document_store=mock_store) - assert retriever.document_store == mock_store - assert retriever.filters == {} - assert retriever.top_k == 10 - assert retriever.vector_function == mock_store.vector_function - - def test_init(self, mock_store): - retriever = PgvectorEmbeddingRetriever( - document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" - ) - assert retriever.document_store == mock_store - assert retriever.filters == {"field": "value"} - assert retriever.top_k == 5 - assert retriever.vector_function == "l2_distance" - - def test_to_dict(self, mock_store): - retriever = PgvectorEmbeddingRetriever( - document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" - ) - res = retriever.to_dict() - t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" - assert res == { - "type": t, - "init_parameters": { - "document_store": { - "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", - "init_parameters": { - "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, - "table_name": "haystack", - "embedding_dimension": 768, - "vector_function": "cosine_similarity", - "recreate_table": True, - "search_strategy": "exact_nearest_neighbor", - "hnsw_recreate_index_if_exists": False, - "hnsw_index_creation_kwargs": {}, - "hnsw_ef_search": None, - }, - }, - "filters": {"field": "value"}, - "top_k": 5, - "vector_function": "l2_distance", - }, - } - - @pytest.mark.usefixtures("patches_for_unit_tests") - def test_from_dict(self, monkeypatch): - monkeypatch.setenv("PG_CONN_STR", "some-connection-string") - t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" - data = { - "type": t, - "init_parameters": { - "document_store": { - "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", - "init_parameters": { - "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, - "table_name": "haystack_test_to_dict", - "embedding_dimension": 768, - "vector_function": "cosine_similarity", - "recreate_table": True, - "search_strategy": "exact_nearest_neighbor", - 
"hnsw_recreate_index_if_exists": False, - "hnsw_index_creation_kwargs": {}, - "hnsw_ef_search": None, - }, - }, - "filters": {"field": "value"}, - "top_k": 5, - "vector_function": "l2_distance", - }, - } - - retriever = PgvectorEmbeddingRetriever.from_dict(data) - document_store = retriever.document_store - - assert isinstance(document_store, PgvectorDocumentStore) - assert isinstance(document_store.connection_string, EnvVarSecret) - assert document_store.table_name == "haystack_test_to_dict" - assert document_store.embedding_dimension == 768 - assert document_store.vector_function == "cosine_similarity" - assert document_store.recreate_table - assert document_store.search_strategy == "exact_nearest_neighbor" - assert not document_store.hnsw_recreate_index_if_exists - assert document_store.hnsw_index_creation_kwargs == {} - assert document_store.hnsw_ef_search is None - - assert retriever.filters == {"field": "value"} - assert retriever.top_k == 5 - assert retriever.vector_function == "l2_distance" - - def test_run(self): - mock_store = Mock(spec=PgvectorDocumentStore) - doc = Document(content="Test doc", embedding=[0.1, 0.2]) - mock_store._embedding_retrieval.return_value = [doc] - - retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") - res = retriever.run(query_embedding=[0.3, 0.5]) - - mock_store._embedding_retrieval.assert_called_once_with( - query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" - ) - - assert res == {"documents": [doc]} diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py new file mode 100644 index 000000000..031c735fd --- /dev/null +++ b/integrations/pgvector/tests/test_retrievers.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.utils.auth import EnvVarSecret +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever, PgvectorKeywordRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +class TestEmbeddingRetriever: + def test_init_default(self, mock_store): + retriever = PgvectorEmbeddingRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + assert retriever.vector_function == mock_store.vector_function + + retriever = PgvectorEmbeddingRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + PgvectorEmbeddingRetriever(document_store=mock_store, filter_policy="invalid") + + def test_init(self, mock_store): + retriever = PgvectorEmbeddingRetriever( + document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + assert retriever.vector_function == "l2_distance" + + def test_to_dict(self, mock_store): + retriever = PgvectorEmbeddingRetriever( + document_store=mock_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + res = retriever.to_dict() + t = 
"haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "language": "english", + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + "filter_policy": "replace", + }, + } + + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + "filter_policy": "replace", + }, + } + + retriever = PgvectorEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert isinstance(document_store.connection_string, EnvVarSecret) + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_index_name == "haystack_hnsw_index" + assert document_store.hnsw_ef_search is None + assert document_store.keyword_index_name == "haystack_keyword_index" + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + assert retriever.vector_function == "l2_distance" + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._embedding_retrieval.return_value = [doc] + + retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") + res = retriever.run(query_embedding=[0.3, 0.5]) + + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" + ) + + assert res == {"documents": [doc]} + + +class TestKeywordRetriever: + def 
test_init_default(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + + retriever = PgvectorKeywordRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + PgvectorKeywordRetriever(document_store=mock_store, filter_policy="invalid") + + def test_init(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + def test_init_with_filter_policy(self, mock_store): + retriever = PgvectorKeywordRetriever( + document_store=mock_store, filters={"field": "value"}, top_k=5, filter_policy=FilterPolicy.MERGE + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.MERGE + + def test_to_dict(self, mock_store): + retriever = PgvectorKeywordRetriever(document_store=mock_store, filters={"field": "value"}, top_k=5) + res = retriever.to_dict() + t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "language": "english", + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + retriever = PgvectorKeywordRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert isinstance(document_store.connection_string, EnvVarSecret) + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert 
document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_index_name == "haystack_hnsw_index" + assert document_store.hnsw_ef_search is None + assert document_store.keyword_index_name == "haystack_keyword_index" + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict_without_filter_policy(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = PgvectorKeywordRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert isinstance(document_store.connection_string, EnvVarSecret) + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_index_name == "haystack_hnsw_index" + assert document_store.hnsw_ef_search is None + assert document_store.keyword_index_name == "haystack_keyword_index" + + assert retriever.filters == {"field": "value"} + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE + assert retriever.top_k == 5 + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._keyword_retrieval.return_value = [doc] + + retriever = PgvectorKeywordRetriever(document_store=mock_store) + res = retriever.run(query="test query") + + mock_store._keyword_retrieval.assert_called_once_with(query="test query", filters={}, top_k=10) + + assert res == {"documents": [doc]} + + def test_run_with_filters(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._keyword_retrieval.return_value = [doc] + + retriever = PgvectorKeywordRetriever( + document_store=mock_store, filter_policy=FilterPolicy.MERGE, filters={"field": "value"} + ) + res = retriever.run(query="test query", filters={"field2": "value2"}) + + mock_store._keyword_retrieval.assert_called_once_with( + query="test query", filters={"field": "value", "field2": "value2"}, top_k=10 + ) + + assert res == {"documents": [doc]} diff --git 
a/integrations/pinecone/CHANGELOG.md b/integrations/pinecone/CHANGELOG.md new file mode 100644 index 000000000..317753192 --- /dev/null +++ b/integrations/pinecone/CHANGELOG.md @@ -0,0 +1,54 @@ +# Changelog + +## [integrations/pinecone-v1.1.0] - 2024-06-11 + +### 🚀 Features + +- Defer the database connection to when it's needed (#804) + +## [integrations/pinecone-v1.0.0] - 2024-06-10 + +### 🚀 Features + +- [**breaking**] Pinecone - support for the new API (#793) + +## [integrations/pinecone-v0.4.1] - 2024-04-02 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Correctly deserialize Pinecone docstore in embedding retriever (#636) + +### 📚 Documentation + +- Update category slug (#442) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Generate API docs for Pinecone (#359) + +### Pinecone + +- Rename retriever (#396) +- Fix imports in example (#398) +- Temporarily skip failing tests (#593) +- Skip `test_write_documents_duplicate_overwrite` test (#637) + +## [integrations/pinecone-v0.2.1] - 2024-01-31 + +### 🐛 Bug Fixes + +- Fix linter (#281) + + + +## [integrations/pinecone-v0.2.0] - 2024-01-23 + +## [integrations/pinecone-v0.1.0] - 2024-01-17 + +## [integrations/pinecone-v0.0.1] - 2023-12-22 + + diff --git a/integrations/pinecone/examples/example.py b/integrations/pinecone/examples/example.py index 71d289ef6..5f7d92ce5 100644 --- a/integrations/pinecone/examples/example.py +++ b/integrations/pinecone/examples/example.py @@ -24,10 +24,10 @@ document_store = PineconeDocumentStore( api_key=Secret.from_token("YOUR-PINECONE-API-KEY"), - environment="gcp-starter", index="default", namespace="default", dimension=768, + spec={"serverless": {"region": "us-east-1", "cloud": "aws"}}, ) indexing = Pipeline() diff --git a/integrations/pinecone/pydoc/config.yml b/integrations/pinecone/pydoc/config.yml index 4265eeecc..f49ec4ab4 100644 --- a/integrations/pinecone/pydoc/config.yml +++ b/integrations/pinecone/pydoc/config.yml @@ -16,7 +16,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Pinecone integration for Haystack category_slug: integrations-api title: Pinecone diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index e59c12b31..866385dd3 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -24,8 +24,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "pinecone-client<3", # our implementation is not compatible with pinecone-client>=3 - # see https://github.com/deepset-ai/haystack-core-integrations/issues/223 + "pinecone-client>=3", # our implementation is not compatible with pinecone-client <3 ] [project.urls] @@ -48,16 +47,19 @@ git_describe_command = 'git describe --tags --match="integrations/pinecone-v[0-9 dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "pytest-xdist", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] # Pinecone tests are slow (require HTTP requests), so we run them in parallel # with pytest-xdist (https://pytest-xdist.readthedocs.io/en/stable/distribution.html) -test = "pytest -n auto --maxprocesses=2 {args:tests}" +test = "pytest -n auto --maxprocesses=2 -x {args:tests}" test-cov = "coverage run -m pytest -n auto --maxprocesses=2 {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage 
report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -68,7 +70,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -151,12 +153,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" diff --git a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py index 57e99962a..76f781f97 100644 --- a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py +++ b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.pinecone import PineconeDocumentStore @@ -55,11 +57,13 @@ def __init__( document_store: PineconeDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ :param document_store: The Pinecone Document Store. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `PineconeDocumentStore`. """ @@ -70,6 +74,9 @@ def __init__( self.document_store = document_store self.filters = filters or {} self.top_k = top_k + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) def to_dict(self) -> Dict[str, Any]: """ @@ -81,6 +88,7 @@ def to_dict(self) -> Dict[str, Any]: self, filters=self.filters, top_k=self.top_k, + filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -96,19 +104,37 @@ def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever": data["init_parameters"]["document_store"] = PineconeDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. 
+ if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query_embedding: List[float]): + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): """ Retrieve documents from the `PineconeDocumentStore`, based on their dense embeddings. :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: Maximum number of `Document`s to return. + :returns: List of Document similar to `query_embedding`. """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + + top_k = top_k or self.top_k + docs = self.document_store._embedding_retrieval( query_embedding=query_embedding, - filters=self.filters, - top_k=self.top_k, + filters=filters, + top_k=top_k, ) return {"documents": docs} diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index d94c54fde..1fd3adf40 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -4,7 +4,7 @@ import io import logging from copy import copy -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional import pandas as pd from haystack import default_from_dict, default_to_dict @@ -13,7 +13,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.filters import convert -import pinecone +from pinecone import Pinecone, PodSpec, ServerlessSpec from .filters import _normalize_filters @@ -25,6 +25,9 @@ TOP_K_LIMIT = 1_000 +DEFAULT_STARTER_PLAN_SPEC = {"serverless": {"region": "us-east-1", "cloud": "aws"}} + + class PineconeDocumentStore: """ A Document Store using [Pinecone vector database](https://www.pinecone.io/). @@ -34,57 +37,91 @@ def __init__( self, *, api_key: Secret = Secret.from_env_var("PINECONE_API_KEY"), # noqa: B008 - environment: str = "us-west1-gcp", index: str = "default", namespace: str = "default", batch_size: int = 100, dimension: int = 768, - **index_creation_kwargs, + spec: Optional[Dict[str, Any]] = None, + metric: Literal["cosine", "euclidean", "dotproduct"] = "cosine", ): """ Creates a new PineconeDocumentStore instance. It is meant to be connected to a Pinecone index and namespace. :param api_key: The Pinecone API key. - :param environment: The Pinecone environment to connect to. :param index: The Pinecone index to connect to. If the index does not exist, it will be created. :param namespace: The Pinecone namespace to connect to. If the namespace does not exist, it will be created at the first write. :param batch_size: The number of documents to write in a single batch. When setting this parameter, - consider [documented Pinecone limits](https://docs.pinecone.io/docs/limits). + consider [documented Pinecone limits](https://docs.pinecone.io/reference/quotas-and-limits). :param dimension: The dimension of the embeddings. This parameter is only used when creating a new index. 
- :param index_creation_kwargs: Additional keyword arguments to pass to the index creation method. - You can find the full list of supported arguments in the - [API reference](https://docs.pinecone.io/reference/create_index). + :param spec: The Pinecone spec to use when creating a new index. Allows choosing between serverless and pod + deployment options and setting additional parameters. Refer to the + [Pinecone documentation](https://docs.pinecone.io/reference/api/control-plane/create_index) for more + details. + If not provided, a default spec with serverless deployment in the `us-east-1` region will be used + (compatible with the free tier). + :param metric: The metric to use for similarity search. This parameter is only used when creating a new index. """ self.api_key = api_key + spec = spec or DEFAULT_STARTER_PLAN_SPEC + self.namespace = namespace + self.batch_size = batch_size + self.metric = metric + self.spec = spec + self.dimension = dimension + self.index_name = index - pinecone.init(api_key=api_key.resolve_value(), environment=environment) + self._index = None + self._dummy_vector = [-10.0] * self.dimension - if index not in pinecone.list_indexes(): - logger.info(f"Index {index} does not exist. Creating a new index.") - pinecone.create_index(name=index, dimension=dimension, **index_creation_kwargs) + @property + def index(self): + if self._index is not None: + return self._index + + client = Pinecone(api_key=self.api_key.resolve_value(), source_tag="haystack") + + if self.index_name not in client.list_indexes().names(): + logger.info(f"Index {self.index_name} does not exist. Creating a new index.") + pinecone_spec = self._convert_dict_spec_to_pinecone_object(self.spec) + client.create_index(name=self.index_name, dimension=self.dimension, spec=pinecone_spec, metric=self.metric) else: - logger.info(f"Index {index} already exists. Connecting to it.") + logger.info( + f"Connecting to existing index {self.index_name}. `dimension`, `spec`, and `metric` will be ignored." + ) - self._index = pinecone.Index(index_name=index) + self._index = client.Index(name=self.index_name) actual_dimension = self._index.describe_index_stats().get("dimension") - if actual_dimension and actual_dimension != dimension: + if actual_dimension and actual_dimension != self.dimension: logger.warning( - f"Dimension of index {index} is {actual_dimension}, but {dimension} was specified. " + f"Dimension of index {self.index_name} is {actual_dimension}, but {self.dimension} was specified. " "The specified dimension will be ignored." "If you need an index with a different dimension, please create a new one." ) - self.dimension = actual_dimension or dimension - + self.dimension = actual_dimension or self.dimension self._dummy_vector = [-10.0] * self.dimension - self.environment = environment - self.index = index - self.namespace = namespace - self.batch_size = batch_size - self.index_creation_kwargs = index_creation_kwargs + + return self._index + + @staticmethod + def _convert_dict_spec_to_pinecone_object(spec: Dict[str, Any]): + """Convert the spec dictionary to a Pinecone spec object""" + + if "serverless" in spec: + serverless_spec = spec["serverless"] + return ServerlessSpec(**serverless_spec) + if "pod" in spec: + pod_spec = spec["pod"] + return PodSpec(**pod_spec) + + msg = ( + "Invalid spec. Must contain either `serverless` or `pod` key. " + "Refer to https://docs.pinecone.io/reference/api/control-plane/create_index for more details." 
+ ) + raise ValueError(msg) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PineconeDocumentStore": @@ -107,12 +144,12 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, api_key=self.api_key.to_dict(), - environment=self.environment, - index=self.index, + spec=self.spec, + index=self.index_name, dimension=self.dimension, namespace=self.namespace, batch_size=self.batch_size, - **self.index_creation_kwargs, + metric=self.metric, ) def count_documents(self) -> int: @@ -120,7 +157,7 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ try: - count = self._index.describe_index_stats()["namespaces"][self.namespace]["vector_count"] + count = self.index.describe_index_stats()["namespaces"][self.namespace]["vector_count"] except KeyError: count = 0 return count @@ -147,9 +184,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D documents_for_pinecone = self._convert_documents_to_pinecone_format(documents) - result = self._index.upsert( - vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size - ) + result = self.index.upsert(vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size) written_docs = result["upserted_count"] return written_docs @@ -187,7 +222,7 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: the document ids to delete """ - self._index.delete(ids=document_ids, namespace=self.namespace) + self.index.delete(ids=document_ids, namespace=self.namespace) def _embedding_retrieval( self, @@ -220,7 +255,7 @@ def _embedding_retrieval( filters = convert(filters) filters = _normalize_filters(filters) if filters else None - result = self._index.query( + result = self.index.query( vector=query_embedding, top_k=top_k, namespace=namespace or self.namespace, diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py index 4a890f5d4..074ed978d 100644 --- a/integrations/pinecone/tests/conftest.py +++ b/integrations/pinecone/tests/conftest.py @@ -2,17 +2,23 @@ import pytest from haystack.document_stores.types import DuplicatePolicy -from pinecone.core.client.exceptions import NotFoundException + +try: + # pinecone-client < 5.0.0 + from pinecone.core.client.exceptions import NotFoundException +except ModuleNotFoundError: + # pinecone-client >= 5.0.0 + from pinecone.exceptions import NotFoundException from haystack_integrations.document_stores.pinecone import PineconeDocumentStore -# This is the approximate time it takes for the documents to be available -SLEEP_TIME = 30 +# This is the approximate time in seconds it takes for the documents to be available +SLEEP_TIME_IN_SECONDS = 15 @pytest.fixture() def sleep_time(): - return SLEEP_TIME + return SLEEP_TIME_IN_SECONDS @pytest.fixture @@ -21,14 +27,12 @@ def document_store(request): This is the most basic requirement for the child class: provide an instance of this document store so the base class can use it. 
""" - environment = "gcp-starter" index = "default" # Use a different namespace for each test so we can run them in parallel namespace = f"{request.node.name}-{int(time.time())}" dimension = 768 store = PineconeDocumentStore( - environment=environment, index=index, namespace=namespace, dimension=dimension, @@ -39,20 +43,20 @@ def document_store(request): def write_documents_and_wait(documents, policy=DuplicatePolicy.NONE): written_docs = original_write_documents(documents, policy) - time.sleep(SLEEP_TIME) + time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs original_delete_documents = store.delete_documents def delete_documents_and_wait(filters): original_delete_documents(filters) - time.sleep(SLEEP_TIME) + time.sleep(SLEEP_TIME_IN_SECONDS) store.write_documents = write_documents_and_wait store.delete_documents = delete_documents_and_wait yield store try: - store._index.delete(delete_all=True, namespace=namespace) + store.index.delete(delete_all=True, namespace=namespace) except NotFoundException: pass diff --git a/integrations/pinecone/tests/test_document_store.py b/integrations/pinecone/tests/test_document_store.py index 6fac67049..90ce2ccff 100644 --- a/integrations/pinecone/tests/test_document_store.py +++ b/integrations/pinecone/tests/test_document_store.py @@ -1,4 +1,5 @@ import os +import time from unittest.mock import patch import numpy as np @@ -6,17 +7,23 @@ from haystack import Document from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest from haystack.utils import Secret +from pinecone import Pinecone, PodSpec, ServerlessSpec from haystack_integrations.document_stores.pinecone import PineconeDocumentStore -@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") +def test_init_is_lazy(_mock_client): + _ = PineconeDocumentStore(api_key=Secret.from_token("fake-api-key")) + _mock_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_init(mock_pinecone): - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60} document_store = PineconeDocumentStore( api_key=Secret.from_token("fake-api-key"), - environment="gcp-starter", index="my_index", namespace="test", batch_size=50, @@ -24,22 +31,23 @@ def test_init(mock_pinecone): metric="euclidean", ) - mock_pinecone.init.assert_called_with(api_key="fake-api-key", environment="gcp-starter") + # Trigger an actual connection + _ = document_store.index + + mock_pinecone.assert_called_with(api_key="fake-api-key", source_tag="haystack") - assert document_store.environment == "gcp-starter" - assert document_store.index == "my_index" + assert document_store.index_name == "my_index" assert document_store.namespace == "test" assert document_store.batch_size == 50 - assert document_store.dimension == 30 - assert document_store.index_creation_kwargs == {"metric": "euclidean"} + assert document_store.dimension == 60 + assert document_store.metric == "euclidean" -@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch): monkeypatch.setenv("PINECONE_API_KEY", "env-api-key") - PineconeDocumentStore( - 
environment="gcp-starter", + ds = PineconeDocumentStore( index="my_index", namespace="test", batch_size=50, @@ -47,15 +55,17 @@ def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch): metric="euclidean", ) - mock_pinecone.init.assert_called_with(api_key="env-api-key", environment="gcp-starter") + # Trigger an actual connection + _ = ds.index + mock_pinecone.assert_called_with(api_key="env-api-key", source_tag="haystack") -@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") + +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_to_from_dict(mock_pinecone, monkeypatch): - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60} monkeypatch.setenv("PINECONE_API_KEY", "env-api-key") document_store = PineconeDocumentStore( - environment="gcp-starter", index="my_index", namespace="test", batch_size=50, @@ -63,6 +73,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch): metric="euclidean", ) + # Trigger an actual connection + _ = document_store.index + dict_output = { "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", "init_parameters": { @@ -73,32 +86,100 @@ def test_to_from_dict(mock_pinecone, monkeypatch): "strict": True, "type": "env_var", }, - "environment": "gcp-starter", "index": "my_index", - "dimension": 30, + "dimension": 60, "namespace": "test", "batch_size": 50, "metric": "euclidean", + "spec": {"serverless": {"region": "us-east-1", "cloud": "aws"}}, }, } assert document_store.to_dict() == dict_output document_store = PineconeDocumentStore.from_dict(dict_output) - assert document_store.environment == "gcp-starter" assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) - assert document_store.index == "my_index" + assert document_store.index_name == "my_index" assert document_store.namespace == "test" assert document_store.batch_size == 50 - assert document_store.dimension == 30 + assert document_store.dimension == 60 + assert document_store.metric == "euclidean" + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} def test_init_fails_wo_api_key(monkeypatch): monkeypatch.delenv("PINECONE_API_KEY", raising=False) with pytest.raises(ValueError): - PineconeDocumentStore( - environment="gcp-starter", + _ = PineconeDocumentStore( index="my_index", - ) + ).index + + +def test_convert_dict_spec_to_pinecone_object_serverless(): + dict_spec = {"serverless": {"region": "us-east-1", "cloud": "aws"}} + pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) + assert isinstance(pinecone_object, ServerlessSpec) + assert pinecone_object.region == "us-east-1" + assert pinecone_object.cloud == "aws" + + +def test_convert_dict_spec_to_pinecone_object_pod(): + dict_spec = {"pod": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"}} + pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) + + assert isinstance(pinecone_object, PodSpec) + assert pinecone_object.replicas == 1 + assert pinecone_object.shards == 1 + assert pinecone_object.pods == 1 + assert pinecone_object.pod_type == "p1.x1" + assert pinecone_object.environment == "us-west1-gcp" + + +def test_convert_dict_spec_to_pinecone_object_fail(): + dict_spec = { + "strange_key": {"replicas": 1, "shards": 1, "pods": 1, 
"pod_type": "p1.x1", "environment": "us-west1-gcp"} + } + with pytest.raises(ValueError): + PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec) + + +@pytest.mark.integration +@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") +def test_serverless_index_creation_from_scratch(sleep_time): + # we use a fixed index name to avoid hitting the limit of Pinecone's free tier (max 5 indexes) + # the index name is defined in the test matrix of the GitHub Actions workflow + # the default value is provided for local testing + index_name = os.environ.get("INDEX_NAME", "serverless-test-index") + + client = Pinecone(api_key=os.environ["PINECONE_API_KEY"]) + try: + client.delete_index(name=index_name) + except Exception: # noqa S110 + pass + + time.sleep(sleep_time) + + ds = PineconeDocumentStore( + index=index_name, + namespace="test", + batch_size=50, + dimension=30, + metric="euclidean", + spec={"serverless": {"region": "us-east-1", "cloud": "aws"}}, + ) + # Trigger the connection + _ = ds.index + + index_description = client.describe_index(name=index_name) + assert index_description["name"] == index_name + assert index_description["dimension"] == 30 + assert index_description["metric"] == "euclidean" + assert index_description["spec"]["serverless"]["region"] == "us-east-1" + assert index_description["spec"]["serverless"]["cloud"] == "aws" + + try: + client.delete_index(name=index_name) + except Exception: # noqa S110 + pass @pytest.mark.integration diff --git a/integrations/pinecone/tests/test_emebedding_retriever.py b/integrations/pinecone/tests/test_embedding_retriever.py similarity index 55% rename from integrations/pinecone/tests/test_emebedding_retriever.py rename to integrations/pinecone/tests/test_embedding_retriever.py index 80cc19010..99be75982 100644 --- a/integrations/pinecone/tests/test_emebedding_retriever.py +++ b/integrations/pinecone/tests/test_embedding_retriever.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack.utils import Secret from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever @@ -16,14 +18,20 @@ def test_init_default(): assert retriever.document_store == mock_store assert retriever.filters == {} assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + retriever = PineconeEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever.filter_policy == FilterPolicy.REPLACE -@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") + with pytest.raises(ValueError): + PineconeEmbeddingRetriever(document_store=mock_store, filter_policy="invalid") + + +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_to_dict(mock_pinecone, monkeypatch): monkeypatch.setenv("PINECONE_API_KEY", "env-api-key") - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512} + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} document_store = PineconeDocumentStore( - environment="gcp-starter", index="default", namespace="test-namespace", batch_size=50, @@ -43,21 +51,23 @@ def test_to_dict(mock_pinecone, monkeypatch): "strict": True, "type": "env_var", }, - "environment": "gcp-starter", "index": "default", "namespace": "test-namespace", 
"batch_size": 50, "dimension": 512, + "spec": {"serverless": {"region": "us-east-1", "cloud": "aws"}}, + "metric": "cosine", }, "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", }, "filters": {}, "top_k": 10, + "filter_policy": "replace", }, } -@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") def test_from_dict(mock_pinecone, monkeypatch): data = { "type": "haystack_integrations.components.retrievers.pinecone.embedding_retriever.PineconeEmbeddingRetriever", @@ -71,11 +81,59 @@ def test_from_dict(mock_pinecone, monkeypatch): "strict": True, "type": "env_var", }, - "environment": "gcp-starter", "index": "default", "namespace": "test-namespace", "batch_size": 50, "dimension": 512, + "spec": {"serverless": {"region": "us-east-1", "cloud": "aws"}}, + "metric": "cosine", + }, + "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", + }, + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + }, + } + + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} + monkeypatch.setenv("PINECONE_API_KEY", "test-key") + retriever = PineconeEmbeddingRetriever.from_dict(data) + + document_store = retriever.document_store + assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) + assert document_store.index_name == "default" + assert document_store.namespace == "test-namespace" + assert document_store.batch_size == 50 + assert document_store.dimension == 512 + assert document_store.metric == "cosine" + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} + + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + + +@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone") +def test_from_dict_no_filter_policy(mock_pinecone, monkeypatch): + data = { + "type": "haystack_integrations.components.retrievers.pinecone.embedding_retriever.PineconeEmbeddingRetriever", + "init_parameters": { + "document_store": { + "init_parameters": { + "api_key": { + "env_vars": [ + "PINECONE_API_KEY", + ], + "strict": True, + "type": "env_var", + }, + "index": "default", + "namespace": "test-namespace", + "batch_size": 50, + "dimension": 512, + "spec": {"serverless": {"region": "us-east-1", "cloud": "aws"}}, + "metric": "cosine", }, "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", }, @@ -84,20 +142,22 @@ def test_from_dict(mock_pinecone, monkeypatch): }, } - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512} + mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 512} monkeypatch.setenv("PINECONE_API_KEY", "test-key") retriever = PineconeEmbeddingRetriever.from_dict(data) document_store = retriever.document_store - assert document_store.environment == "gcp-starter" assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True) - assert document_store.index == "default" + assert document_store.index_name == "default" assert document_store.namespace == "test-namespace" assert document_store.batch_size == 50 assert document_store.dimension == 512 + assert document_store.metric == "cosine" + assert document_store.spec == {"serverless": {"region": "us-east-1", "cloud": "aws"}} assert 
retriever.filters == {} assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE def test_run(): diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py index bb0855aa1..40c9cdb10 100644 --- a/integrations/pinecone/tests/test_filters.py +++ b/integrations/pinecone/tests/test_filters.py @@ -13,10 +13,6 @@ class TestFilters(FilterDocumentsTest): def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): for doc in received: - # Pinecone seems to convert strings to datetime objects (undocumented behavior) - # We convert them back to strings to compare them - if "date" in doc.meta: - doc.meta["date"] = doc.meta["date"].isoformat() # Pinecone seems to convert integers to floats (undocumented behavior) # We convert them back to integers to compare them if "number" in doc.meta: diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md new file mode 100644 index 000000000..d17f549da --- /dev/null +++ b/integrations/qdrant/CHANGELOG.md @@ -0,0 +1,145 @@ +# Changelog + +## [integrations/qdrant-v4.1.2] - 2024-07-15 + +### 🐛 Bug Fixes + +- `qdrant` - Fallback to default filter policy when deserializing retrievers without the init parameter (#902) + +## [integrations/qdrant-v4.1.1] - 2024-07-10 + +### 🚀 Features + +- Add filter_policy to qdrant integration (#819) + +### 🐛 Bug Fixes + +- Errors in convert_filters_to_qdrant (#870) + +## [integrations/qdrant-v4.1.0] - 2024-07-03 + +### 🚀 Features + +- Add `score_threshold` to Qdrant Retrievers (#860) +- Qdrant - add support for BM42 (#864) + +## [integrations/qdrant-v4.0.0] - 2024-07-02 + +### 🚜 Refactor + +- [**breaking**] Qdrant - remove unused init parameters: `content_field`, `name_field`, `embedding_field`, and `duplicate_documents` (#861) +- [**breaking**] Qdrant - set `scale_score` default value to `False` (#862) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/qdrant-v3.8.1] - 2024-06-20 + +### 📚 Documentation + +- Added docstrings for QdrantDocumentStore (#808) + +## [integrations/qdrant-v3.8.0] - 2024-06-06 + +### 🚀 Features + +- Add force_disable_check_same_thread init param for Qdrant local client (#779) + +## [integrations/qdrant-v3.7.0] - 2024-05-24 + +### 🚀 Features + +- Make get_distance and recreate_collection public, replace deprecated recreate_collection function (#754) + +## [integrations/qdrant-v3.6.0] - 2024-05-24 + +### 🚀 Features + +- Defer database connection to the first usage (#748) + +## [integrations/qdrant-v3.5.0] - 2024-04-24 + +## [integrations/qdrant-v3.4.0] - 2024-04-23 + +### Qdrant + +- Add embedding retrieval example (#666) + +## [integrations/qdrant-v3.3.1] - 2024-04-12 + +### Qdrant + +- Add migration utility function for Sparse Embedding support (#659) + +## [integrations/qdrant-v3.3.0] - 2024-04-12 + +### 🚀 Features + +- *(Qdrant)* Start to work on sparse vector integration (#578) + +## [integrations/qdrant-v3.2.1] - 2024-04-09 + +### 🐛 Bug Fixes + +- Fix haystack-ai pin (#649) + + + +## [integrations/qdrant-v3.2.0] - 2024-03-27 + +### 🚀 Features + +- *(Qdrant)* Allow payload indexing + on disk vectors (#553) +- Qdrant datetime filtering support (#570) + +### 🐛 Bug Fixes + +- Fix linter errors (#282) + + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Fixes (#518) + + + +### 🚜 Refactor + +- [**breaking**] Qdrant - update secret 
management (#405) + +### 📚 Documentation + +- Update category slug (#442) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Generate API docs for Qdrant (#361) + +## [integrations/qdrant-v3.0.0] - 2024-01-22 + +### Refact + +- [**breaking**] Change import paths (#255) + +## [integrations/qdrant-v2.0.1] - 2024-01-18 + +### 🚀 Features + +- Add Qdrant integration (#98) + +### 🐛 Bug Fixes + +- Fix import paths for beta5 (#237) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + + diff --git a/integrations/qdrant/pydoc/config.yml b/integrations/qdrant/pydoc/config.yml index 835eeb2e9..58ededdb5 100644 --- a/integrations/qdrant/pydoc/config.yml +++ b/integrations/qdrant/pydoc/config.yml @@ -17,7 +17,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Qdrant integration for Haystack category_slug: integrations-api title: Qdrant diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 05de42585..8b9c44cc7 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.0.1", "qdrant-client"] +dependencies = ["haystack-ai", "qdrant-client>=1.10.0"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -44,12 +44,14 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/qdrant-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -60,7 +62,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -105,7 +107,8 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Allow boolean arguments in function definition - "FBT001", "FBT002", + "FBT001", + "FBT002", # Ignore checks for possible passwords "S105", "S106", @@ -140,12 +143,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 7befa3612..275a46f95 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -1,8 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict from haystack.dataclasses.sparse_embedding import SparseEmbedding +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.qdrant import QdrantDocumentStore +from qdrant_client.http import models @component @@ -12,6 +15,7 @@ class QdrantEmbeddingRetriever: Usage example: ```python + from haystack.dataclasses import Document from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever from haystack_integrations.document_stores.qdrant import QdrantDocumentStore @@ -33,21 +37,28 @@ class QdrantEmbeddingRetriever: def __init__( self, document_store: QdrantDocumentStore, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: int = 10, - scale_score: bool = True, + scale_score: bool = False, return_embedding: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + score_threshold: Optional[float] = None, ): """ Create a QdrantEmbeddingRetriever component. :param document_store: An instance of QdrantDocumentStore. - :param filters: A dictionary with filters to narrow down the search space. Default is None. - :param top_k: The maximum number of documents to retrieve. Default is 10. - :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. - :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to retrieve. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the embedding of the retrieved Documents. + :param filter_policy: Policy to determine how filters are applied. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the `similarity` function specified in the Document Store. + E.g. for cosine similarity only higher scores will be returned. - :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. + :raises ValueError: If `document_store` is not an instance of `QdrantDocumentStore`. 
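For readers skimming the retriever changes, a small, hypothetical sketch (not part of the patch) of the new `filter_policy` and `score_threshold` parameters on `QdrantEmbeddingRetriever`, assuming an in-memory `QdrantDocumentStore` with the default 768-dimensional embeddings:

```python
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

document_store = QdrantDocumentStore(":memory:", recreate_index=True)

retriever = QdrantEmbeddingRetriever(
    document_store=document_store,
    filter_policy="merge",   # or FilterPolicy.MERGE; "replace" remains the default
    score_threshold=0.5,     # hits below this score are dropped (direction depends on the metric)
)

# Runtime filters are combined with init-time filters according to `filter_policy`.
result = retriever.run(query_embedding=[0.1] * 768, top_k=5)
print(result["documents"])
```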
""" if not isinstance(document_store, QdrantDocumentStore): @@ -59,6 +70,10 @@ def __init__( self._top_k = top_k self._scale_score = scale_score self._return_embedding = return_embedding + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._score_threshold = score_threshold def to_dict(self) -> Dict[str, Any]: """ @@ -72,8 +87,10 @@ def to_dict(self) -> Dict[str, Any]: document_store=self._document_store, filters=self._filters, top_k=self._top_k, + filter_policy=self._filter_policy.value, scale_score=self._scale_score, return_embedding=self._return_embedding, + score_threshold=self._score_threshold, ) d["init_parameters"]["document_store"] = self._document_store.to_dict() @@ -91,16 +108,21 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run( self, query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: Optional[int] = None, scale_score: Optional[bool] = None, return_embedding: Optional[bool] = None, + score_threshold: Optional[float] = None, ): """ Run the Embedding Retriever on the given input data. @@ -110,16 +132,20 @@ def run( :param top_k: The maximum number of documents to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the embedding of the retrieved Documents. + :param score_threshold: A minimal score threshold for the result. :returns: The retrieved documents. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + docs = self._document_store._query_by_embedding( query_embedding=query_embedding, - filters=filters or self._filters, + filters=filters, top_k=top_k or self._top_k, scale_score=scale_score or self._scale_score, return_embedding=return_embedding or self._return_embedding, + score_threshold=score_threshold or self._score_threshold, ) return {"documents": docs} @@ -134,7 +160,7 @@ class QdrantSparseEmbeddingRetriever: ```python from haystack_integrations.components.retrievers.qdrant import QdrantSparseEmbeddingRetriever from haystack_integrations.document_stores.qdrant import QdrantDocumentStore - from haystack.dataclasses.sparse_embedding import SparseEmbedding + from haystack.dataclasses import Document, SparseEmbedding document_store = QdrantDocumentStore( ":memory:", @@ -155,21 +181,28 @@ class QdrantSparseEmbeddingRetriever: def __init__( self, document_store: QdrantDocumentStore, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: int = 10, - scale_score: bool = True, + scale_score: bool = False, return_embedding: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + score_threshold: Optional[float] = None, ): """ Create a QdrantSparseEmbeddingRetriever component. :param document_store: An instance of QdrantDocumentStore. 
- :param filters: A dictionary with filters to narrow down the search space. Default is None. - :param top_k: The maximum number of documents to retrieve. Default is 10. - :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. - :param return_embedding: Whether to return the sparse embedding of the retrieved Documents. Default is False. - - :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to retrieve. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the sparse embedding of the retrieved Documents. + :param filter_policy: Policy to determine how filters are applied. Defaults to "replace". + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + + :raises ValueError: If `document_store` is not an instance of `QdrantDocumentStore`. """ if not isinstance(document_store, QdrantDocumentStore): @@ -181,6 +214,10 @@ def __init__( self._top_k = top_k self._scale_score = scale_score self._return_embedding = return_embedding + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._score_threshold = score_threshold def to_dict(self) -> Dict[str, Any]: """ @@ -195,7 +232,9 @@ def to_dict(self) -> Dict[str, Any]: filters=self._filters, top_k=self._top_k, scale_score=self._scale_score, + filter_policy=self._filter_policy.value, return_embedding=self._return_embedding, + score_threshold=self._score_threshold, ) d["init_parameters"]["document_store"] = self._document_store.to_dict() @@ -213,35 +252,49 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantSparseEmbeddingRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run( self, query_sparse_embedding: SparseEmbedding, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: Optional[int] = None, scale_score: Optional[bool] = None, return_embedding: Optional[bool] = None, + score_threshold: Optional[float] = None, ): """ Run the Sparse Embedding Retriever on the given input data. :param query_sparse_embedding: Sparse Embedding of the query. - :param filters: A dictionary with filters to narrow down the search space. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: The maximum number of documents to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the embedding of the retrieved Documents. 
+ :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. :returns: The retrieved documents. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + docs = self._document_store._query_by_sparse( query_sparse_embedding=query_sparse_embedding, - filters=filters or self._filters, + filters=filters, top_k=top_k or self._top_k, scale_score=scale_score or self._scale_score, return_embedding=return_embedding or self._return_embedding, + score_threshold=score_threshold or self._score_threshold, ) return {"documents": docs} @@ -257,7 +310,7 @@ class QdrantHybridRetriever: ```python from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever from haystack_integrations.document_stores.qdrant import QdrantDocumentStore - from haystack.dataclasses.sparse_embedding import SparseEmbedding + from haystack.dataclasses import Document, SparseEmbedding document_store = QdrantDocumentStore( ":memory:", @@ -283,9 +336,11 @@ class QdrantHybridRetriever: def __init__( self, document_store: QdrantDocumentStore, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: int = 10, return_embedding: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + score_threshold: Optional[float] = None, ): """ Create a QdrantHybridRetriever component. @@ -294,6 +349,11 @@ def __init__( :param filters: A dictionary with filters to narrow down the search space. :param top_k: The maximum number of documents to retrieve. :param return_embedding: Whether to return the embeddings of the retrieved Documents. + :param filter_policy: Policy to determine how filters are applied. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. """ @@ -306,6 +366,10 @@ def __init__( self._filters = filters self._top_k = top_k self._return_embedding = return_embedding + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._score_threshold = score_threshold def to_dict(self) -> Dict[str, Any]: """ @@ -319,7 +383,9 @@ def to_dict(self) -> Dict[str, Any]: document_store=self._document_store.to_dict(), filters=self._filters, top_k=self._top_k, + filter_policy=self._filter_policy.value, return_embedding=self._return_embedding, + score_threshold=self._score_threshold, ) @classmethod @@ -334,6 +400,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantHybridRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. 
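The `run` methods above now delegate filter merging to Haystack's `apply_filter_policy`, which the diff imports from `haystack.document_stores.types.filter_policy`. A small, hypothetical sketch of the intended behavior (the filter values are made up):

```python
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

init_filters = {"field": "meta.type", "operator": "==", "value": "article"}
runtime_filters = {"field": "meta.genre", "operator": "==", "value": "economy"}

# REPLACE (the default): runtime filters take precedence when provided.
print(apply_filter_policy(FilterPolicy.REPLACE, init_filters, runtime_filters))

# MERGE: init-time and runtime filters are combined into one condition
# (the exact merge semantics are defined by Haystack, not by this integration).
print(apply_filter_policy(FilterPolicy.MERGE, init_filters, runtime_filters))
```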
+ if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -341,28 +411,38 @@ def run( self, query_embedding: List[float], query_sparse_embedding: SparseEmbedding, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], models.Filter]] = None, top_k: Optional[int] = None, return_embedding: Optional[bool] = None, + score_threshold: Optional[float] = None, ): """ Run the Sparse Embedding Retriever on the given input data. :param query_embedding: Dense embedding of the query. :param query_sparse_embedding: Sparse embedding of the query. - :param filters: A dictionary with filters to narrow down the search space. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: The maximum number of documents to return. :param return_embedding: Whether to return the embedding of the retrieved Documents. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. :returns: The retrieved documents. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + docs = self._document_store._query_hybrid( query_embedding=query_embedding, query_sparse_embedding=query_sparse_embedding, - filters=filters or self._filters, + filters=filters, top_k=top_k or self._top_k, return_embedding=return_embedding or self._return_embedding, + score_threshold=score_threshold or self._score_threshold, ) return {"documents": docs} diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py index 96bd4f37a..01645a999 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py @@ -17,7 +17,6 @@ def convert_haystack_documents_to_qdrant_points( documents: List[Document], *, - embedding_field: str, use_sparse_embeddings: bool, ) -> List[rest.PointStruct]: points = [] @@ -26,7 +25,7 @@ def convert_haystack_documents_to_qdrant_points( if use_sparse_embeddings: vector = {} - dense_vector = payload.pop(embedding_field, None) + dense_vector = payload.pop("embedding", None) if dense_vector is not None: vector[DENSE_VECTORS_NAME] = dense_vector @@ -36,7 +35,7 @@ def convert_haystack_documents_to_qdrant_points( vector[SPARSE_VECTORS_NAME] = sparse_vector_instance else: - vector = payload.pop(embedding_field) or {} + vector = payload.pop("embedding") or {} _id = convert_id(payload.get("id")) point = rest.PointStruct( diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 120735411..d55cbd71c 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -5,7 +5,6 @@ import numpy as np import qdrant_client -from grpc import RpcError from haystack 
import default_from_dict, default_to_dict from haystack.dataclasses import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding @@ -50,6 +49,44 @@ def get_batches_from_generator(iterable, n): class QdrantDocumentStore: + """ + QdrantDocumentStore is a Document Store for Qdrant. + It can be used with any Qdrant instance: in-memory, disk-persisted, Docker-based, + and Qdrant Cloud Cluster deployments. + + Usage example by creating an in-memory instance: + + ```python + from haystack.dataclasses.document import Document + from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + + document_store = QdrantDocumentStore( + ":memory:", + recreate_index=True + ) + document_store.write_documents([ + Document(content="This is first", embedding=[0.0]*5), + Document(content="This is second", embedding=[0.1, 0.2, 0.3, 0.4, 0.5]) + ]) + ``` + + Usage example with Qdrant Cloud: + + ```python + from haystack.dataclasses.document import Document + from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + + document_store = QdrantDocumentStore( + url="https://xxxxxx-xxxxx-xxxxx-xxxx-xxxxxxxxx.us-east.aws.cloud.qdrant.io:6333", + api_key="", + ) + document_store.write_documents([ + Document(content="This is first", embedding=[0.0]*5), + Document(content="This is second", embedding=[0.1, 0.2, 0.3, 0.4, 0.5]) + ]) + ``` + """ + SIMILARITY: ClassVar[Dict[str, str]] = { "cosine": rest.Distance.COSINE, "dot_product": rest.Distance.DOT, @@ -66,20 +103,18 @@ def __init__( https: Optional[bool] = None, api_key: Optional[Secret] = None, prefix: Optional[str] = None, - timeout: Optional[float] = None, + timeout: Optional[int] = None, host: Optional[str] = None, path: Optional[str] = None, + force_disable_check_same_thread: bool = False, index: str = "Document", embedding_dim: int = 768, on_disk: bool = False, - content_field: str = "content", - name_field: str = "name", - embedding_field: str = "embedding", use_sparse_embeddings: bool = False, + sparse_idf: bool = False, similarity: str = "cosine", return_embedding: bool = False, progress_bar: bool = True, - duplicate_documents: str = "overwrite", recreate_index: bool = False, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -96,23 +131,94 @@ def __init__( scroll_size: int = 10_000, payload_fields_to_index: Optional[List[dict]] = None, ): - super().__init__() - - metadata = metadata or {} - self.client = qdrant_client.QdrantClient( - location=location, - url=url, - port=port, - grpc_port=grpc_port, - prefer_grpc=prefer_grpc, - https=https, - api_key=api_key.resolve_value() if api_key else None, - prefix=prefix, - timeout=timeout, - host=host, - path=path, - metadata=metadata, - ) + """ + :param location: + If `memory` - use in-memory Qdrant instance. + If `str` - use it as a URL parameter. + If `None` - use default values for host and port. + :param url: + Either host or str of `Optional[scheme], host, Optional[port], Optional[prefix]`. + :param port: + Port of the REST API interface. + :param grpc_port: + Port of the gRPC interface. + :param prefer_grpc: + If `True` - use gRPC interface whenever possible in custom methods. + :param https: + If `True` - use HTTPS(SSL) protocol. + :param api_key: + API key for authentication in Qdrant Cloud. + :param prefix: + If not `None` - add prefix to the REST URL path. + Example: service/v1 will result in http://localhost:6333/service/v1/{qdrant-endpoint} + for REST API. + :param timeout: + Timeout for REST and gRPC API requests. 
+ :param host: + Host name of Qdrant service. If `url` and `host` are `None`, set to `localhost`. + :param path: + Persistence path for QdrantLocal. + :param force_disable_check_same_thread: + For QdrantLocal, force disable check_same_thread. + Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient. + :param index: + Name of the index. + :param embedding_dim: + Dimension of the embeddings. + :param on_disk: + Whether to store the collection on disk. + :param use_sparse_embeddings: + If set to `True`, enables support for sparse embeddings. + :param sparse_idf: + If set to `True`, computes the Inverse Document Frequency (IDF) when using sparse embeddings. + It is required to use techniques like BM42. It is ignored if `use_sparse_embeddings` is `False`. + :param similarity: + The similarity metric to use. + :param return_embedding: + Whether to return embeddings in the search results. + :param progress_bar: + Whether to show a progress bar or not. + :param recreate_index: + Whether to recreate the index. + :param shard_number: + Number of shards in the collection. + :param replication_factor: + Replication factor for the collection. + Defines how many copies of each shard will be created. Effective only in distributed mode. + :param write_consistency_factor: + Write consistency factor for the collection. Minimum value is 1. + Defines how many replicas should apply to the operation for it to be considered successful. + Increasing this number makes the collection more resilient to inconsistencies + but will cause failures if not enough replicas are available. + Effective only in distributed mode. + :param on_disk_payload: + If `True`, the point's payload will not be stored in memory and + will be read from the disk every time it is requested. + This setting saves RAM by slightly increasing response time. + Note: indexed payload values remain in RAM. + :param hnsw_config: + Params for HNSW index. + :param optimizers_config: + Params for optimizer. + :param wal_config: + Params for Write-Ahead-Log. + :param quantization_config: + Params for quantization. If `None`, quantization will be disabled. + :param init_from: + Use data stored in another collection to initialize this collection. + :param wait_result_from_api: + Whether to wait for the result from the API after each request. + :param metadata: + Additional metadata to include with the documents. + :param write_batch_size: + The batch size for writing documents. + :param scroll_size: + The scroll size for reading documents. + :param payload_fields_to_index: + List of payload fields to index.
+ """ + + self._client = None # Store the Qdrant client specific attributes self.location = location @@ -126,7 +232,8 @@ def __init__( self.timeout = timeout self.host = host self.path = path - self.metadata = metadata + self.force_disable_check_same_thread = force_disable_check_same_thread + self.metadata = metadata or {} self.api_key = api_key # Store the Qdrant collection specific attributes @@ -143,26 +250,51 @@ def __init__( self.recreate_index = recreate_index self.payload_fields_to_index = payload_fields_to_index self.use_sparse_embeddings = use_sparse_embeddings - - # Make sure the collection is properly set up - self._set_up_collection( - index, embedding_dim, recreate_index, similarity, use_sparse_embeddings, on_disk, payload_fields_to_index - ) - + self.sparse_idf = use_sparse_embeddings and sparse_idf self.embedding_dim = embedding_dim self.on_disk = on_disk - self.content_field = content_field - self.name_field = name_field - self.embedding_field = embedding_field self.similarity = similarity self.index = index self.return_embedding = return_embedding self.progress_bar = progress_bar - self.duplicate_documents = duplicate_documents self.write_batch_size = write_batch_size self.scroll_size = scroll_size + @property + def client(self): + if not self._client: + self._client = qdrant_client.QdrantClient( + location=self.location, + url=self.url, + port=self.port, + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + https=self.https, + api_key=self.api_key.resolve_value() if self.api_key else None, + prefix=self.prefix, + timeout=self.timeout, + host=self.host, + path=self.path, + metadata=self.metadata, + force_disable_check_same_thread=self.force_disable_check_same_thread, + ) + # Make sure the collection is properly set up + self._set_up_collection( + self.index, + self.embedding_dim, + self.recreate_index, + self.similarity, + self.use_sparse_embeddings, + self.sparse_idf, + self.on_disk, + self.payload_fields_to_index, + ) + return self._client + def count_documents(self) -> int: + """ + Returns the number of documents present in the Document Store. + """ try: response = self.client.count( collection_name=self.index, @@ -176,13 +308,22 @@ def count_documents(self) -> int: def filter_documents( self, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], rest.Filter]] = None, ) -> List[Document]: - if filters and not isinstance(filters, dict): - msg = "Filter must be a dictionary" + """ + Returns the documents that match the provided filters. + + For a detailed specification of the filters, refer to the + [documentation](https://docs.haystack.deepset.ai/docs/metadata-filtering) + + :param filters: The filters to apply to the document list. + :returns: A list of documents that match the given filters. + """ + if filters and not isinstance(filters, dict) and not isinstance(filters, rest.Filter): + msg = "Filter must be a dictionary or an instance of `qdrant_client.http.models.Filter`" raise ValueError(msg) - if filters and "operator" not in filters: + if filters and not isinstance(filters, rest.Filter) and "operator" not in filters: filters = convert_legacy_filters(filters) return list( self.get_documents_generator( @@ -195,11 +336,26 @@ def write_documents( documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL, ): + """ + Writes documents to Qdrant using the specified policy. + The QdrantDocumentStore can handle duplicate documents based on the given policy. 
+ The available policies are: + - `FAIL`: The operation will raise an error if any document already exists. + - `OVERWRITE`: Existing documents will be overwritten with the new ones. + - `SKIP`: Existing documents will be skipped, and only new documents will be added. + + :param documents: A list of Document objects to write to Qdrant. + :param policy: The policy for handling duplicate documents. + + :returns: The number of documents written to the document store. + """ for doc in documents: if not isinstance(doc, Document): msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}." raise ValueError(msg) - self._set_up_collection(self.index, self.embedding_dim, False, self.similarity, self.use_sparse_embeddings) + self._set_up_collection( + self.index, self.embedding_dim, False, self.similarity, self.use_sparse_embeddings, self.sparse_idf + ) if len(documents) == 0: logger.warning("Calling QdrantDocumentStore.write_documents() with empty list") @@ -216,7 +372,6 @@ def write_documents( for document_batch in batched_documents: batch = convert_haystack_documents_to_qdrant_points( document_batch, - embedding_field=self.embedding_field, use_sparse_embeddings=self.use_sparse_embeddings, ) @@ -230,6 +385,11 @@ def write_documents( return len(document_objects) def delete_documents(self, ids: List[str]): + """ + Deletes documents that match the provided `document_ids` from the document store. + + :param document_ids: the document ids to delete + """ ids = [convert_id(_id) for _id in ids] try: self.client.delete( @@ -244,10 +404,24 @@ def delete_documents(self, ids: List[str]): @classmethod def from_dict(cls, data: Dict[str, Any]) -> "QdrantDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ params = inspect.signature(self.__init__).parameters # type: ignore # All the __init__ params must be set as attributes # Set as init_parms without default values @@ -260,8 +434,15 @@ def to_dict(self) -> Dict[str, Any]: def get_documents_generator( self, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], rest.Filter]] = None, ) -> Generator[Document, None, None]: + """ + Returns a generator that yields documents from Qdrant based on the provided filters. + + :param filters: Filters applied to the retrieved documents. + :returns: A generator that yields documents retrieved from Qdrant. + """ + index = self.index qdrant_filters = convert_filters_to_qdrant(filters) @@ -290,6 +471,16 @@ def get_documents_by_id( ids: List[str], index: Optional[str] = None, ) -> List[Document]: + """ + Retrieves documents from Qdrant by their IDs. + + :param ids: + A list of document IDs to retrieve. + :param index: + The name of the index to retrieve documents from. + :returns: + A list of documents. 
+ """ index = index or self.index documents: List[Document] = [] @@ -311,11 +502,31 @@ def get_documents_by_id( def _query_by_sparse( self, query_sparse_embedding: SparseEmbedding, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], rest.Filter]] = None, top_k: int = 10, - scale_score: bool = True, + scale_score: bool = False, return_embedding: bool = False, + score_threshold: Optional[float] = None, ) -> List[Document]: + """ + Queries Qdrant using a sparse embedding and returns the most relevant documents. + + :param query_sparse_embedding: Sparse embedding of the query. + :param filters: Filters applied to the retrieved documents. + :param top_k: Maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents. + :param return_embedding: Whether to return the embeddings of the retrieved documents. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + + :returns: List of documents that are most similar to `query_sparse_embedding`. + + :raises QdrantStoreError: + If the Document Store was initialized with `use_sparse_embeddings=False`. + """ + if not self.use_sparse_embeddings: message = ( "You are trying to query using sparse embeddings, but the Document Store " @@ -338,6 +549,7 @@ def _query_by_sparse( query_filter=qdrant_filters, limit=top_k, with_vectors=return_embedding, + score_threshold=score_threshold, ) results = [ convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) @@ -353,11 +565,27 @@ def _query_by_sparse( def _query_by_embedding( self, query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], rest.Filter]] = None, top_k: int = 10, - scale_score: bool = True, + scale_score: bool = False, return_embedding: bool = False, + score_threshold: Optional[float] = None, ) -> List[Document]: + """ + Queries Qdrant using a dense embedding and returns the most relevant documents. + + :param query_embedding: Dense embedding of the query. + :param filters: Filters applied to the retrieved documents. + :param top_k: Maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents. + :param return_embedding: Whether to return the embeddings of the retrieved documents. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + + :returns: List of documents that are most similar to `query_embedding`. 
+ """ qdrant_filters = convert_filters_to_qdrant(filters) points = self.client.search( @@ -369,6 +597,7 @@ def _query_by_embedding( query_filter=qdrant_filters, limit=top_k, with_vectors=return_embedding, + score_threshold=score_threshold, ) results = [ convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) @@ -388,9 +617,10 @@ def _query_hybrid( self, query_embedding: List[float], query_sparse_embedding: SparseEmbedding, - filters: Optional[Dict[str, Any]] = None, + filters: Optional[Union[Dict[str, Any], rest.Filter]] = None, top_k: int = 10, return_embedding: bool = False, + score_threshold: Optional[float] = None, ) -> List[Document]: """ Retrieves documents based on dense and sparse embeddings and fuses the results using Reciprocal Rank Fusion. @@ -400,9 +630,13 @@ def _query_hybrid( :param query_embedding: Dense embedding of the query. :param query_sparse_embedding: Sparse embedding of the query. - :param filters: Filters applied to the retrieved Documents. - :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved documents. + :param top_k: Maximum number of documents to return. :param return_embedding: Whether to return the embeddings of the retrieved documents. + :param score_threshold: A minimal score threshold for the result. + Score of the returned result might be higher or smaller than the threshold + depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. :returns: List of Document that are most similar to `query_embedding` and `query_sparse_embedding`. @@ -433,6 +667,7 @@ def _query_hybrid( limit=top_k, with_payload=True, with_vector=return_embedding, + score_threshold=score_threshold, ) dense_request = rest.SearchRequest( @@ -464,7 +699,17 @@ def _query_hybrid( return results - def _get_distance(self, similarity: str) -> rest.Distance: + def get_distance(self, similarity: str) -> rest.Distance: + """ + Retrieves the distance metric for the specified similarity measure. + + :param similarity: + The similarity measure to retrieve the distance. + :returns: + The corresponding rest.Distance object. + :raises QdrantStoreError: + If the provided similarity measure is not supported. + """ try: return self.SIMILARITY[similarity] except KeyError as ke: @@ -495,34 +740,48 @@ def _set_up_collection( recreate_collection: bool, similarity: str, use_sparse_embeddings: bool, + sparse_idf: bool, on_disk: bool = False, payload_fields_to_index: Optional[List[dict]] = None, ): - distance = self._get_distance(similarity) + """ + Sets up the Qdrant collection with the specified parameters. + :param collection_name: + The name of the collection to set up. + :param embedding_dim: + The dimension of the embeddings. + :param recreate_collection: + Whether to recreate the collection if it already exists. + :param similarity: + The similarity measure to use. + :param use_sparse_embeddings: + Whether to use sparse embeddings. + :param sparse_idf: + Whether to compute the Inverse Document Frequency (IDF) when using sparse embeddings. Required for BM42. + :param on_disk: + Whether to store the collection on disk. + :param payload_fields_to_index: + List of payload fields to index. + + :raises QdrantStoreError: + If the collection exists with incompatible settings. + :raises ValueError: + If the collection exists with a different similarity measure or embedding dimension. 
+ + """ + distance = self.get_distance(similarity) - if recreate_collection: + if recreate_collection or not self.client.collection_exists(collection_name): # There is no need to verify the current configuration of that - # collection. It might be just recreated again. - self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings) + # collection. It might be just recreated again or does not exist yet. + self.recreate_collection( + collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings, sparse_idf + ) # Create Payload index if payload_fields_to_index is provided self._create_payload_index(collection_name, payload_fields_to_index) return - try: - # Check if the collection already exists and validate its - # current configuration with the parameters. - collection_info = self.client.get_collection(collection_name) - except (UnexpectedResponse, RpcError, ValueError): - # That indicates the collection does not exist, so it can be - # safely created with any configuration. - # - # Qdrant local raises ValueError if the collection is not found, but - # with the remote server UnexpectedResponse / RpcError is raised. - # Until that's unified, we need to catch both. - self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings) - # Create Payload index if payload_fields_to_index is provided - self._create_payload_index(collection_name, payload_fields_to_index) - return + collection_info = self.client.get_collection(collection_name) has_named_vectors = ( isinstance(collection_info.config.params.vectors, dict) @@ -573,14 +832,37 @@ def _set_up_collection( ) raise ValueError(msg) - def _recreate_collection( + def recreate_collection( self, collection_name: str, distance, embedding_dim: int, - on_disk: bool, - use_sparse_embeddings: bool, + on_disk: Optional[bool] = None, + use_sparse_embeddings: Optional[bool] = None, + sparse_idf: bool = False, ): + """ + Recreates the Qdrant collection with the specified parameters. + + :param collection_name: + The name of the collection to recreate. + :param distance: + The distance metric to use for the collection. + :param embedding_dim: + The dimension of the embeddings. + :param on_disk: + Whether to store the collection on disk. + :param use_sparse_embeddings: + Whether to use sparse embeddings. + :param sparse_idf: + Whether to compute the Inverse Document Frequency (IDF) when using sparse embeddings. Required for BM42. + """ + if on_disk is None: + on_disk = self.on_disk + + if use_sparse_embeddings is None: + use_sparse_embeddings = self.use_sparse_embeddings + # dense vectors configuration vectors_config = rest.VectorParams(size=embedding_dim, on_disk=on_disk, distance=distance) @@ -592,11 +874,15 @@ def _recreate_collection( SPARSE_VECTORS_NAME: rest.SparseVectorParams( index=rest.SparseIndexParams( on_disk=on_disk, - ) + ), + modifier=rest.Modifier.IDF if sparse_idf else None, ), } - self.client.recreate_collection( + if self.client.collection_exists(collection_name): + self.client.delete_collection(collection_name) + + self.client.create_collection( collection_name=collection_name, vectors_config=vectors_config, sparse_vectors_config=sparse_vectors_config if use_sparse_embeddings else None, @@ -623,12 +909,7 @@ def _handle_duplicate_documents( :param documents: A list of Haystack Document objects. :param index: name of the index - :param duplicate_documents: Handle duplicates document based on parameter options. 
- Parameter options : ( 'skip','overwrite','fail') - skip (default option): Ignore the duplicates documents - overwrite: Update any existing documents with the same ID when adding documents. - fail: an error is raised if the document ID of the document being added already - exists. + :param policy: The duplicate policy to use when writing documents. :returns: A list of Haystack Document objects. """ @@ -648,10 +929,10 @@ def _handle_duplicate_documents( def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: """ - Drop duplicates documents based on same hash ID + Drop duplicate documents based on same hash ID. :param documents: A list of Haystack Document objects. - :param index: name of the index + :param index: Name of the index. :returns: A list of Haystack Document objects. """ _hash_ids: Set = set() diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index c4387b1e5..69fd7cbbd 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -4,27 +4,55 @@ from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError from qdrant_client.http import models -from .converters import convert_id - COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() def convert_filters_to_qdrant( - filter_term: Optional[Union[List[dict], dict]] = None, -) -> Optional[models.Filter]: - """Converts Haystack filters to the format used by Qdrant.""" + filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True +) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]: + """Converts Haystack filters to the format used by Qdrant. + + :param filter_term: the haystack filter to be converted to qdrant. + :param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns + a single models.Filter object; if False, it may return a list of filters or conditions for further processing. + + :returns: a single Qdrant Filter in the parent call or a list of such Filters in recursive calls. + + :raises FilterError: If the invalid filter criteria is provided or if an unknown operator is encountered. 
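A small sketch of the two input shapes the converter now accepts, a Haystack filter dict and a ready-made Qdrant `Filter` that is passed through untouched (the metadata fields are made up for the example):

```python
from haystack_integrations.document_stores.qdrant.filters import convert_filters_to_qdrant
from qdrant_client.http import models

# A nested Haystack-style filter: file name must match AND one of two scores must be high.
haystack_filter = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]},
        {
            "operator": "OR",
            "conditions": [
                {"field": "meta.category1", "operator": ">=", "value": 0.85},
                {"field": "meta.category2", "operator": ">=", "value": 0.85},
            ],
        },
    ],
}
qdrant_filter = convert_filters_to_qdrant(haystack_filter)
assert isinstance(qdrant_filter, models.Filter)

# A native Qdrant filter is returned as-is, so it can be used directly in
# filter_documents() or as a retriever's `filters` argument.
native_filter = models.Filter(
    must_not=[models.FieldCondition(key="meta.name", match=models.MatchValue(value="name_0"))]
)
assert convert_filters_to_qdrant(native_filter) is native_filter
```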
+ + """ + if isinstance(filter_term, models.Filter): + return filter_term if not filter_term: return None - must_clauses, should_clauses, must_not_clauses = [], [], [] + must_clauses: List[models.Filter] = [] + should_clauses: List[models.Filter] = [] + must_not_clauses: List[models.Filter] = [] + # Indicates if there are multiple same LOGICAL OPERATORS on each level + # and prevents them from being combined + same_operator_flag = False + conditions, qdrant_filter, current_level_operators = ( + [], + [], + [], + ) if isinstance(filter_term, dict): filter_term = [filter_term] + # ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ======== + for item in filter_term: operator = item.get("operator") + + # Check for repeated similar operators on each level + same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS + if not same_operator_flag: + current_level_operators.append(operator) + if operator is None: msg = "Operator not found in filters" raise FilterError(msg) @@ -33,12 +61,23 @@ def convert_filters_to_qdrant( msg = f"'conditions' not found for '{operator}'" raise FilterError(msg) - if operator == "AND": - must_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) - elif operator == "OR": - should_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) - elif operator == "NOT": - must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + if operator in LOGICAL_OPERATORS: + # Recursively process nested conditions + current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or [] + + # When same_operator_flag is set to True, + # ensure each clause is appended as an independent list to avoid merging distinct clauses. + if operator == "AND": + must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter + elif operator == "OR": + should_clauses = ( + [should_clauses, current_filter] if same_operator_flag else should_clauses + current_filter + ) + elif operator == "NOT": + must_not_clauses = ( + [must_not_clauses, current_filter] if same_operator_flag else must_not_clauses + current_filter + ) + elif operator in COMPARISON_OPERATORS: field = item.get("field") value = item.get("value") @@ -46,20 +85,106 @@ def convert_filters_to_qdrant( msg = f"'field' or 'value' not found for '{operator}'" raise FilterError(msg) - must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value)) + parsed_conditions = _parse_comparison_operation(comparison_operation=operator, key=field, value=value) + + # check if the parsed_conditions are models.Filter or models.Condition + for condition in parsed_conditions: + if isinstance(condition, models.Filter): + qdrant_filter.append(condition) + else: + conditions.append(condition) + else: msg = f"Unknown operator {operator} used in filters" raise FilterError(msg) - payload_filter = models.Filter( - must=must_clauses or None, - should=should_clauses or None, - must_not=must_not_clauses or None, - ) + # ======== PROCESS FILTER ITEMS ON EACH LEVEL ======== + + # If same logical operators have separate clauses, create separate filters + if same_operator_flag: + qdrant_filter = build_filters_for_repeated_operators( + must_clauses, should_clauses, must_not_clauses, qdrant_filter + ) + + # else append a single Filter for existing clauses + elif must_clauses or should_clauses or must_not_clauses: + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + 
should=should_clauses or None, + must_not=must_not_clauses or None, + ) + ) + + # In case of parent call, a single Filter is returned + if is_parent_call: + # If qdrant_filter has just a single Filter in parent call, + # then it might be returned instead. + if len(qdrant_filter) == 1 and isinstance(qdrant_filter[0], models.Filter): + return qdrant_filter[0] + else: + must_clauses.extend(conditions) + return models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) - filter_result = _squeeze_filter(payload_filter) + # Store conditions of each level in output of the loop + elif conditions: + qdrant_filter.extend(conditions) - return filter_result + return qdrant_filter + + +def build_filters_for_repeated_operators( + must_clauses, + should_clauses, + must_not_clauses, + qdrant_filter, +) -> List[models.Filter]: + """ + Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator. + + :param must_clauses: a nested list of must clauses or an empty list. + :param should_clauses: a nested list of should clauses or an empty list. + :param must_not_clauses: a nested list of must_not clauses or an empty list. + :param qdrant_filter: a list where the generated Filter objects will be appended. + This list will be modified in-place. + + + :returns: the modified `qdrant_filter` list with appended generated Filter objects. + """ + + if any(isinstance(i, list) for i in must_clauses): + for i in must_clauses: + qdrant_filter.append( + models.Filter( + must=i or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) + ) + if any(isinstance(i, list) for i in should_clauses): + for i in should_clauses: + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + should=i or None, + must_not=must_not_clauses or None, + ) + ) + if any(isinstance(i, list) for i in must_not_clauses): + for i in must_clauses: + qdrant_filter.append( + models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=i or None, + ) + ) + + return qdrant_filter def _parse_comparison_operation( @@ -91,7 +216,7 @@ def _parse_comparison_operation( def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition: if isinstance(value, str) and " " in value: - models.FieldCondition(key=key, match=models.MatchText(text=value)) + return models.FieldCondition(key=key, match=models.MatchText(text=value)) return models.FieldCondition(key=key, match=models.MatchValue(value=value)) @@ -183,52 +308,6 @@ def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Cond raise FilterError(msg) -def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: - return models.HasIdCondition( - has_id=[ - # Ids are converted into their internal representation - convert_id(item) - for item in id_values - ] - ) - - -def _squeeze_filter(payload_filter: models.Filter) -> models.Filter: - """ - Simplify given payload filter, if the nested structure might be unnested. - That happens if there is a single clause in that filter. - :param payload_filter: - :returns: - """ - filter_parts = { - "must": payload_filter.must, - "should": payload_filter.should, - "must_not": payload_filter.must_not, - } - - total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) - if total_clauses == 0 or total_clauses > 1: - return payload_filter - - # Payload filter has just a single clause provided (either must, should - # or must_not). 
If that single clause is also of a models.Filter type, - # then it might be returned instead. - for part_name, filter_part in filter_parts.items(): - if not filter_part: - continue - - subfilter = filter_part[0] - if not isinstance(subfilter, models.Filter): - # The inner statement is a simple condition like models.FieldCondition - # so it cannot be simplified. - continue - - if subfilter.must: - return models.Filter(**{part_name: subfilter.must}) - - return payload_filter - - def is_datetime_string(value: str) -> bool: try: datetime.fromisoformat(value) diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 6c8e46710..3871dbff0 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -22,14 +22,12 @@ def test_to_dict(): "index": "test", "embedding_dim": 768, "on_disk": False, - "content_field": "content", - "name_field": "name", - "embedding_field": "embedding", + "force_disable_check_same_thread": False, "use_sparse_embeddings": False, + "sparse_idf": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, - "duplicate_documents": "overwrite", "recreate_index": False, "shard_number": None, "replication_factor": None, @@ -61,14 +59,12 @@ def test_from_dict(): "index": "test", "embedding_dim": 768, "on_disk": False, - "content_field": "content", - "name_field": "name", - "embedding_field": "embedding", + "force_disable_check_same_thread": False, "use_sparse_embeddings": True, + "sparse_idf": True, "similarity": "cosine", "return_embedding": False, "progress_bar": True, - "duplicate_documents": "overwrite", "recreate_index": True, "shard_number": None, "quantization_config": None, @@ -85,15 +81,13 @@ def test_from_dict(): assert all( [ document_store.index == "test", - document_store.content_field == "content", - document_store.name_field == "name", - document_store.embedding_field == "embedding", + document_store.force_disable_check_same_thread is False, document_store.use_sparse_embeddings is True, + document_store.sparse_idf is True, document_store.on_disk is False, document_store.similarity == "cosine", document_store.return_embedding is False, document_store.progress_bar, - document_store.duplicate_documents == "overwrite", document_store.recreate_index is True, document_store.shard_number is None, document_store.replication_factor is None, diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index cbd5c62d0..c388a10cf 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -12,7 +12,12 @@ WriteDocumentsTest, _random_embeddings, ) -from haystack_integrations.document_stores.qdrant.document_store import QdrantDocumentStore, QdrantStoreError +from haystack_integrations.document_stores.qdrant.document_store import ( + SPARSE_VECTORS_NAME, + QdrantDocumentStore, + QdrantStoreError, +) +from qdrant_client.http import models as rest class TestQdrantDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): @@ -26,6 +31,11 @@ def document_store(self) -> QdrantDocumentStore: use_sparse_embeddings=False, ) + def test_init_is_lazy(self): + with patch("haystack_integrations.document_stores.qdrant.document_store.qdrant_client") as mocked_qdrant: + QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) + mocked_qdrant.assert_not_called() + def assert_documents_are_equal(self, received: 
List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. @@ -44,6 +54,23 @@ def test_write_documents(self, document_store: QdrantDocumentStore): with pytest.raises(DuplicateDocumentError): document_store.write_documents(docs, DuplicatePolicy.FAIL) + def test_sparse_configuration(self): + document_store = QdrantDocumentStore( + ":memory:", + recreate_index=True, + use_sparse_embeddings=True, + sparse_idf=True, + ) + + client = document_store.client + sparse_config = client.get_collection("Document").config.params.sparse_vectors + + assert SPARSE_VECTORS_NAME in sparse_config + + # check that the `sparse_idf` parameter takes effect + assert hasattr(sparse_config[SPARSE_VECTORS_NAME], "modifier") + assert sparse_config[SPARSE_VECTORS_NAME].modifier == rest.Modifier.IDF + def test_query_hybrid(self, generate_sparse_embedding): document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index 2b35dfebc..fd070bda9 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -5,6 +5,7 @@ from haystack.testing.document_store import FilterDocumentsTest from haystack.utils.filters import FilterError from haystack_integrations.document_stores.qdrant import QdrantDocumentStore +from qdrant_client.http import models class TestQdrantStoreBaseTests(FilterDocumentsTest): @@ -17,6 +18,21 @@ def document_store(self) -> QdrantDocumentStore: wait_result_from_api=True, ) + def test_filter_documents_with_qdrant_filters(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + filters=models.Filter( + must_not=[ + models.FieldCondition(key="meta.number", match=models.MatchValue(value=100)), + models.FieldCondition(key="meta.name", match=models.MatchValue(value="name_0")), + ] + ) + ) + self.assert_documents_are_equal( + result, + [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")], + ) + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. 
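A rough end-to-end sketch of the sparse path exercised by these tests, assuming `Document` carries a `sparse_embedding` field and that the sparse retriever's `run` takes `query_sparse_embedding`, as the fixtures here suggest; the embedding values are hand-written placeholders.

```python
from haystack.dataclasses import Document, SparseEmbedding
from haystack_integrations.components.retrievers.qdrant import QdrantSparseEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

# Store configured for sparse vectors with the IDF modifier (BM42-style scoring).
document_store = QdrantDocumentStore(
    location=":memory:",
    index="demo_sparse",
    use_sparse_embeddings=True,
    sparse_idf=True,
)

# Sparse embeddings are (token index, weight) pairs; real values would come
# from a sparse embedder, these are placeholders.
document_store.write_documents(
    [
        Document(
            content="Qdrant supports sparse vectors",
            embedding=[0.1] * 768,  # the store also keeps a dense vector per document
            sparse_embedding=SparseEmbedding(indices=[3, 17, 42], values=[0.7, 0.2, 0.5]),
        )
    ]
)

retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
result = retriever.run(query_sparse_embedding=SparseEmbedding(indices=[17, 42], values=[0.4, 0.9]))
print(result["documents"])
```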
@@ -45,6 +61,112 @@ def test_not_operator(self, document_store, filterable_docs): [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")], ) + def test_filter_criteria(self, document_store): + documents = [ + Document( + content="This is test document 1.", + meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}}, + ), + Document( + content="This is test document 2.", + meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}}, + ), + Document( + content="This is test document 3.", + meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}}, + ), + ] + + document_store.write_documents(documents) + filter_criteria = { + "operator": "AND", + "conditions": [ + {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85}, + {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85}, + ], + }, + ], + } + result = document_store.filter_documents(filter_criteria) + self.assert_documents_are_equal( + result, + [ + d + for d in documents + if (d.meta.get("file_name") in ["file1", "file2"]) + and ( + (d.meta.get("classification").get("details").get("category1") >= 0.85) + or (d.meta.get("classification").get("details").get("category2") >= 0.85) + ) + ], + ) + + def test_complex_filter_criteria(self, document_store): + documents = [ + Document( + content="This is test document 1.", + meta={ + "file_name": "file1", + "classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}}, + }, + ), + Document( + content="This is test document 2.", + meta={ + "file_name": "file2", + "classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}}, + }, + ), + Document( + content="This is test document 3.", + meta={ + "file_name": "file3", + "classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}}, + }, + ), + ] + + document_store.write_documents(documents) + filter_criteria = { + "operator": "AND", + "conditions": [ + {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2", "file3"]}, + { + "operator": "AND", + "conditions": [ + {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85}, + { + "operator": "OR", + "conditions": [ + {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.8}, + {"field": "meta.classification.details.category3", "operator": ">=", "value": 0.9}, + ], + }, + ], + }, + ], + } + result = document_store.filter_documents(filter_criteria) + self.assert_documents_are_equal( + result, + [ + d + for d in documents + if (d.meta.get("file_name") in ["file1", "file2", "file3"]) + and ( + (d.meta.get("classification").get("details").get("category1") >= 0.85) + and ( + (d.meta.get("classification").get("details").get("category2") >= 0.8) + or (d.meta.get("classification").get("details").get("category3") >= 0.9) + ) + ) + ], + ) + # ======== OVERRIDES FOR NONE VALUED FILTERS ======== def test_comparison_equal_with_none(self, document_store, filterable_docs): diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 47fec5968..a92f6917f 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -1,7 +1,9 @@ from typing import List from 
unittest.mock import Mock +import pytest from haystack.dataclasses import Document, SparseEmbedding +from haystack.document_stores.types import FilterPolicy from haystack.testing.document_store import ( FilterableDocsFixtureMixin, _random_embeddings, @@ -21,7 +23,15 @@ def test_init_default(self): assert retriever._document_store == document_store assert retriever._filters is None assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False + assert retriever._score_threshold is None + + retriever = QdrantEmbeddingRetriever(document_store=document_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + QdrantEmbeddingRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False) @@ -47,14 +57,12 @@ def test_to_dict(self): "index": "test", "embedding_dim": 768, "on_disk": False, - "content_field": "content", - "name_field": "name", - "embedding_field": "embedding", + "force_disable_check_same_thread": False, "use_sparse_embeddings": False, + "sparse_idf": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, - "duplicate_documents": "overwrite", "recreate_index": False, "shard_number": None, "replication_factor": None, @@ -74,8 +82,10 @@ def test_to_dict(self): }, "filters": None, "top_k": 10, - "scale_score": True, + "filter_policy": "replace", + "scale_score": False, "return_embedding": False, + "score_threshold": None, }, } @@ -89,8 +99,10 @@ def test_from_dict(self): }, "filters": None, "top_k": 5, + "filter_policy": "replace", "scale_score": False, "return_embedding": True, + "score_threshold": None, }, } retriever = QdrantEmbeddingRetriever.from_dict(data) @@ -98,8 +110,10 @@ def test_from_dict(self): assert retriever._document_store.index == "test" assert retriever._filters is None assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._scale_score is False assert retriever._return_embedding is True + assert retriever._score_threshold is None def test_run(self, filterable_docs: List[Document]): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) @@ -117,6 +131,53 @@ def test_run(self, filterable_docs: List[Document]): for document in results: assert document.embedding is None + def test_run_filters(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) + + document_store.write_documents(filterable_docs) + + retriever = QdrantEmbeddingRetriever( + document_store=document_store, + filters={"field": "meta.name", "operator": "==", "value": "name_0"}, + filter_policy=FilterPolicy.MERGE, + ) + + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"] + assert len(results) == 3 + + results = retriever.run( + query_embedding=_random_embeddings(768), + top_k=5, + filters={"field": "meta.chapter", "operator": "==", "value": "abstract"}, + return_embedding=False, + )["documents"] + assert len(results) == 3 + + for document in results: + assert document.embedding is None + + def test_run_with_score_threshold(self): + document_store = QdrantDocumentStore( + embedding_dim=4, location=":memory:", similarity="cosine", index="Boi", use_sparse_embeddings=False + ) + + 
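A minimal sketch of what `filter_policy` changes at query time, using made-up documents and embeddings; REPLACE is the default, and MERGE hands both filters to Haystack's `apply_filter_policy` helper instead.

```python
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

document_store = QdrantDocumentStore(location=":memory:", index="policies", embedding_dim=4)
document_store.write_documents(
    [
        Document(content="intro", meta={"chapter": "intro"}, embedding=[0.1, 0.2, 0.3, 0.4]),
        Document(content="abstract", meta={"chapter": "abstract"}, embedding=[0.4, 0.3, 0.2, 0.1]),
    ]
)

# REPLACE (the default): a filter passed to run() fully replaces the init-time filter.
retriever = QdrantEmbeddingRetriever(
    document_store=document_store,
    filters={"field": "meta.chapter", "operator": "==", "value": "intro"},
    filter_policy=FilterPolicy.REPLACE,
)
docs = retriever.run(
    query_embedding=[0.1, 0.2, 0.3, 0.4],
    filters={"field": "meta.chapter", "operator": "==", "value": "abstract"},
)["documents"]
print([d.meta["chapter"] for d in docs])  # ['abstract']

# With filter_policy=FilterPolicy.MERGE, init-time and runtime filters are
# combined by apply_filter_policy rather than replaced.
```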
document_store.write_documents( + [ + Document( + content="Yet another document", + embedding=[-0.1, -0.9, -10.0, -0.2], + ), + Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="Another document", embedding=[0.8, 0.8, 0.5, 1.0]), + ] + ) + + retriever = QdrantEmbeddingRetriever(document_store=document_store) + results = retriever.run( + query_embedding=[0.9, 0.9, 0.9, 0.9], top_k=5, return_embedding=False, score_threshold=0.5 + )["documents"] + assert len(results) == 2 + def test_run_with_sparse_activated(self, filterable_docs: List[Document]): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) @@ -143,7 +204,15 @@ def test_init_default(self): assert retriever._document_store == document_store assert retriever._filters is None assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False + assert retriever._score_threshold is None + + retriever = QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test") @@ -169,14 +238,12 @@ def test_to_dict(self): "index": "test", "embedding_dim": 768, "on_disk": False, - "content_field": "content", - "name_field": "name", - "embedding_field": "embedding", + "force_disable_check_same_thread": False, "use_sparse_embeddings": False, + "sparse_idf": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, - "duplicate_documents": "overwrite", "recreate_index": False, "shard_number": None, "replication_factor": None, @@ -196,8 +263,10 @@ def test_to_dict(self): }, "filters": None, "top_k": 10, - "scale_score": True, + "scale_score": False, "return_embedding": False, + "filter_policy": "replace", + "score_threshold": None, }, } @@ -213,6 +282,33 @@ def test_from_dict(self): "top_k": 5, "scale_score": False, "return_embedding": True, + "filter_policy": "replace", + "score_threshold": None, + }, + } + retriever = QdrantSparseEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._scale_score is False + assert retriever._return_embedding is True + assert retriever._score_threshold is None + + def test_from_dict_no_filter_policy(self): + data = { + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "scale_score": False, + "return_embedding": True, + "score_threshold": None, }, } retriever = QdrantSparseEmbeddingRetriever.from_dict(data) @@ -220,8 +316,10 @@ def test_from_dict(self): assert retriever._document_store.index == "test" assert retriever._filters is None assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE assert retriever._scale_score is False assert retriever._return_embedding is True + assert 
retriever._score_threshold is None def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) @@ -252,7 +350,15 @@ def test_init_default(self): assert retriever._document_store == document_store assert retriever._filters is None assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False + assert retriever._score_threshold is None + + retriever = QdrantHybridRetriever(document_store=document_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + QdrantHybridRetriever(document_store=document_store, filter_policy="invalid") def test_to_dict(self): document_store = QdrantDocumentStore(location=":memory:", index="test") @@ -278,14 +384,12 @@ def test_to_dict(self): "index": "test", "embedding_dim": 768, "on_disk": False, - "content_field": "content", - "name_field": "name", - "embedding_field": "embedding", + "force_disable_check_same_thread": False, "use_sparse_embeddings": False, + "sparse_idf": False, "similarity": "cosine", "return_embedding": False, "progress_bar": True, - "duplicate_documents": "overwrite", "recreate_index": False, "shard_number": None, "replication_factor": None, @@ -305,11 +409,37 @@ def test_to_dict(self): }, "filters": None, "top_k": 5, + "filter_policy": "replace", "return_embedding": True, + "score_threshold": None, }, } def test_from_dict(self): + data = { + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "filter_policy": "replace", + "return_embedding": True, + "score_threshold": None, + }, + } + retriever = QdrantHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._return_embedding + assert retriever._score_threshold is None + + def test_from_dict_no_filter_policy(self): data = { "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever", "init_parameters": { @@ -320,6 +450,7 @@ def test_from_dict(self): "filters": None, "top_k": 5, "return_embedding": True, + "score_threshold": None, }, } retriever = QdrantHybridRetriever.from_dict(data) @@ -327,7 +458,9 @@ def test_from_dict(self): assert retriever._document_store.index == "test" assert retriever._filters is None assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE assert retriever._return_embedding + assert retriever._score_threshold is None def test_run(self): mock_store = Mock(spec=QdrantDocumentStore) diff --git a/integrations/ragas/CHANGELOG.md b/integrations/ragas/CHANGELOG.md new file mode 100644 index 000000000..7055f1931 --- /dev/null +++ b/integrations/ragas/CHANGELOG.md @@ -0,0 +1,47 @@ +# Changelog + +## [integrations/ragas-v1.0.0] - 2024-07-24 + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Ragas - remove context relevancy metric (#917) + 
+## [integrations/ragas-v0.2.0] - 2024-04-23 + +## [integrations/ragas-v0.1.3] - 2024-04-09 + +### 🐛 Bug Fixes + +- Fix haystack-ai pin (#649) + + + +### 📚 Documentation + +- Disable-class-def (#556) + +## [integrations/ragas-v0.1.2] - 2024-03-08 + +### 📚 Documentation + +- Update `ragas-haystack` docstrings (#529) + +## [integrations/ragas-v0.1.1] - 2024-02-23 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +### Build + +- Pin `ragas` dependency to `0.1.1` (#476) + + diff --git a/integrations/ragas/pydoc/config.yml b/integrations/ragas/pydoc/config.yml index 3a8e843fe..97d8d808e 100644 --- a/integrations/ragas/pydoc/config.yml +++ b/integrations/ragas/pydoc/config.yml @@ -18,7 +18,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Ragas integration for Haystack category_slug: integrations-api title: Ragas diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index a763ab17c..edc33eee1 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ragas"] +dependencies = ["haystack-ai", "ragas>=0.1.11"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ragas" @@ -41,12 +41,14 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/ragas-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", "pytest-asyncio"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -57,7 +59,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -140,12 +142,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py index 06b29bedf..5d6ed16bc 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py @@ -12,7 +12,6 @@ AspectCritique, # type: ignore ContextPrecision, # type: ignore 
ContextRecall, # type: ignore - ContextRelevancy, # type: ignore ContextUtilization, # type: ignore Faithfulness, # type: ignore ) @@ -81,10 +80,6 @@ class RagasMetric(RagasBaseEnum): #: Parameters - `name: str, definition: str, strictness: int` ASPECT_CRITIQUE = "aspect_critique" - #: Context relevancy.\ - #: Inputs - `questions: List[str], contexts: List[List[str]]` - CONTEXT_RELEVANCY = "context_relevancy" - #: Answer relevancy.\ #: Inputs - `questions: List[str], contexts: List[List[str]], responses: List[str]`\ #: Parameters - `strictness: int` @@ -329,11 +324,6 @@ def aspect_critique(output: Result, _: RagasMetric, metric_params: Optional[Dict OutputConverters.aspect_critique, init_parameters=["name", "definition", "strictness"], ), - RagasMetric.CONTEXT_RELEVANCY: MetricDescriptor.new( - RagasMetric.CONTEXT_RELEVANCY, - ContextRelevancy, - InputConverters.question_context, # type: ignore - ), RagasMetric.ANSWER_RELEVANCY: MetricDescriptor.new( RagasMetric.ANSWER_RELEVANCY, AnswerRelevancy, diff --git a/integrations/ragas/tests/test_evaluator.py b/integrations/ragas/tests/test_evaluator.py index 0decc96cd..fc8901c32 100644 --- a/integrations/ragas/tests/test_evaluator.py +++ b/integrations/ragas/tests/test_evaluator.py @@ -51,7 +51,6 @@ def evaluate(self, _, metric: Metric, **kwargs): RagasMetric.CONTEXT_UTILIZATION: Result(scores=Dataset.from_list([{"context_utilization": 1.0}])), RagasMetric.CONTEXT_RECALL: Result(scores=Dataset.from_list([{"context_recall": 0.9}])), RagasMetric.ASPECT_CRITIQUE: Result(scores=Dataset.from_list([{"harmfulness": 1.0}])), - RagasMetric.CONTEXT_RELEVANCY: Result(scores=Dataset.from_list([{"context_relevancy": 1.0}])), RagasMetric.ANSWER_RELEVANCY: Result(scores=Dataset.from_list([{"answer_relevancy": 0.4}])), } assert isinstance(metric, Metric) @@ -76,7 +75,6 @@ def evaluate(self, _, metric: Metric, **kwargs): "large?", }, ), - (RagasMetric.CONTEXT_RELEVANCY, None), (RagasMetric.ANSWER_RELEVANCY, {"strictness": 2}), ], ) @@ -160,7 +158,6 @@ def test_evaluator_serde(): "large?", }, ), - (RagasMetric.CONTEXT_RELEVANCY, {"questions": [], "contexts": []}, None), (RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, {"strictness": 2}), ], ) @@ -177,7 +174,6 @@ def test_evaluator_valid_inputs(current_metric, inputs, params): @pytest.mark.parametrize( "current_metric, inputs, error_string, params", [ - (RagasMetric.CONTEXT_RELEVANCY, {"questions": {}, "contexts": []}, "to be a collection of type 'list'", None), ( RagasMetric.FAITHFULNESS, {"questions": [1], "contexts": [2], "responses": [3]}, @@ -256,12 +252,6 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): "large?", }, ), - ( - RagasMetric.CONTEXT_RELEVANCY, - {"questions": ["q8"], "contexts": [["c8"]]}, - [[(None, 1.0)]], - None, - ), ( RagasMetric.ANSWER_RELEVANCY, {"questions": ["q9"], "contexts": [["c9"]], "responses": ["r9"]}, @@ -293,6 +283,7 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para # This integration test validates the evaluator by running it against the # OpenAI API. It is parameterized by the metric, the inputs to the evaluator # and the metric parameters. 
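A short usage sketch of the evaluator these metrics belong to; it assumes `RagasEvaluator` exposes the metric inputs shown in these tests through `run()`, and it needs `OPENAI_API_KEY` because Ragas judges with an LLM.

```python
from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric

# Faithfulness takes questions, contexts and responses, as in the test inputs above.
evaluator = RagasEvaluator(metric=RagasMetric.FAITHFULNESS)

result = evaluator.run(
    questions=["Which is the most popular global sport?"],
    contexts=[["Football is undoubtedly the world's most popular sport."]],
    responses=["Football is the most popular sport."],
)
print(result)
```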
+@pytest.mark.asyncio @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.parametrize( "metric, inputs, metric_params", @@ -337,7 +328,6 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para "large?", }, ), - (RagasMetric.CONTEXT_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS}, None), ( RagasMetric.ANSWER_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, diff --git a/integrations/unstructured/CHANGELOG.md b/integrations/unstructured/CHANGELOG.md new file mode 100644 index 000000000..847a295c1 --- /dev/null +++ b/integrations/unstructured/CHANGELOG.md @@ -0,0 +1,67 @@ +# Changelog + +## [integrations/unstructured-v0.4.1] - 2024-06-28 + +### 🚀 Features + +- Generate unstructured API docs (#350) +- *(unstructured)* Add element index as metadata (#382) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +### Unstructured + +- Review docstrings (#531) + +### Build + +- Add `psutil` dependency to Unstructured integration (#854) + +### Unstructured + +- Add missing from_dict method (#532) + +## [integrations/unstructured-v0.3.1] - 2024-02-05 + +### Unstructured + +- Fix metadata order mixed up (#336) + +## [integrations/unstructured-v0.3.0] - 2024-01-23 + +### 🐛 Bug Fixes + +- Fix license headers + +- Fix project urls (#96) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +### ⚙️ Miscellaneous Tasks + +- Pin unstructured api version (#105) + +### Feat + +- UnstructuredFileConverter meta field (#242) + + diff --git a/integrations/unstructured/pydoc/config.yml b/integrations/unstructured/pydoc/config.yml index 7179a2607..f2b4061a4 100644 --- a/integrations/unstructured/pydoc/config.yml +++ b/integrations/unstructured/pydoc/config.yml @@ -14,7 +14,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Unstructured integration for Haystack category_slug: integrations-api title: Unstructured diff --git a/integrations/unstructured/pyproject.toml b/integrations/unstructured/pyproject.toml index 9430b8732..b5de6c66a 100644 --- a/integrations/unstructured/pyproject.toml +++ b/integrations/unstructured/pyproject.toml @@ -7,27 +7,21 @@ name = "unstructured-fileconverter-haystack" dynamic = ["version"] description = 'Haystack 2.x component to convert files into Documents using the Unstructured API' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] 
-dependencies = [ - "haystack-ai", - "unstructured", -] +dependencies = ["haystack-ai", "unstructured", "psutil"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/unstructured#readme" @@ -49,49 +43,30 @@ git_describe_command = 'git describe --tags --match="integrations/unstructured-v dependencies = [ "coverage[toml]>=6.5", "pytest", + "pytest-rerunfailures", "pytest-xdist", "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -137,9 +112,15 @@ ignore = [ # Allow boolean positional values in function calls, like `dict.get(... 
True)` "FBT003", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", ] unfixable = [ # Don't touch unused imports @@ -163,25 +144,13 @@ parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.pytest.ini_options] minversion = "6.0" -markers = [ - "unit: unit tests", - "integration: integration tests" -] +markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*", - "unstructured.*", -] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "unstructured.*"] ignore_missing_imports = true diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md new file mode 100644 index 000000000..bddde1b7d --- /dev/null +++ b/integrations/weaviate/CHANGELOG.md @@ -0,0 +1,67 @@ +# Changelog + +## [unreleased] + +### 🚀 Features + +- Add filter_policy to weaviate integration (#824) + +### 🐛 Bug Fixes + +- Weaviate filter error (#811) +- Fix connection to Weaviate Cloud Service (#624) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/weaviate-v2.1.0] - 2024-06-10 + +### 🚀 Features + +- Defer the database connection to when it's needed (#802) + +### 🐛 Bug Fixes + +- Weaviate schema class name conversion which preserves PascalCase (#707) + +## [integrations/weaviate-v2.0.0] - 2024-03-25 + +### 📚 Documentation + +- Disable-class-def (#556) +- Fix docstrings (#586) + +### Weaviate + +- Migrate from weaviate python client v3 to v4 (#463) + +## [integrations/weaviate-v1.0.2] - 2024-02-27 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Fix weaviate auth tests (#488) + + + +### 📚 Documentation + +- Update category slug (#442) + +### Weaviate + +- Make retrievers return dicts (#491) + +## [integrations/weaviate-v1.0.0] - 2024-02-15 + +### 🚀 Features + +- Generate weaviate API docs (#351) + +## [integrations/weaviate-v0.0.0] - 2024-01-10 + + diff --git a/integrations/weaviate/pydoc/config.yml b/integrations/weaviate/pydoc/config.yml index e62b21591..ab585ebb7 100644 --- a/integrations/weaviate/pydoc/config.yml +++ b/integrations/weaviate/pydoc/config.yml @@ -18,7 +18,7 @@ processors: - type: smart - type: crossref renderer: - type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer excerpt: Weaviate integration for Haystack category_slug: integrations-api title: Weaviate diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index b0a618505..14b60fe12 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -47,12 +47,14 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "ipython"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" cov-report = ["- coverage combine", "coverage report"] cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -63,7 +65,7 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff {args:.}", "black --check --diff {args:.}"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] all = ["style", "typing"] @@ -75,7 +77,7 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -select = [ +lint.select = [ "A", "ARG", "B", @@ -102,7 +104,7 @@ select = [ "W", "YTT", ] -ignore = [ +lint.ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` @@ -118,18 +120,18 @@ ignore = [ "PLR0913", "PLR0915", ] -unfixable = [ +lint.unfixable = [ # Don't touch unused imports "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["src"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] @@ -141,12 +143,8 @@ parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] -show_missing=true -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py index 34bfd0c7d..015ff6e67 100644 --- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .bm25_retriever import WeaviateBM25Retriever from .embedding_retriever import WeaviateEmbeddingRetriever diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py index 6deef5eb6..fec0b81e6 100644 --- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py @@ -1,6 +1,12 @@ -from typing import Any, Dict, List, Optional +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, 
Optional, Union from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore @@ -26,6 +32,7 @@ def __init__( document_store: WeaviateDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Create a new instance of WeaviateBM25Retriever. @@ -36,10 +43,14 @@ def __init__( Custom filters applied when running the retriever :param top_k: Maximum number of documents to return + :param filter_policy: Policy to determine how filters are applied. """ self._document_store = document_store self._filters = filters or {} self._top_k = top_k + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) def to_dict(self) -> Dict[str, Any]: """ @@ -52,6 +63,7 @@ def to_dict(self) -> Dict[str, Any]: self, filters=self._filters, top_k=self._top_k, + filter_policy=self._filter_policy.value, document_store=self._document_store.to_dict(), ) @@ -68,6 +80,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever": data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -77,12 +94,14 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio :param query: The query text. - :param filters: - Filters to use when running the retriever. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: The maximum number of documents to return. 
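A brief sketch of the BM25 retriever with the new parameter; the URL and metadata fields are placeholders for a running Weaviate instance, and only the REPLACE behaviour is asserted in the comments.

```python
from haystack.document_stores.types import FilterPolicy
from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore

# Placeholder URL; point this at your Weaviate instance.
document_store = WeaviateDocumentStore(url="http://localhost:8080")

retriever = WeaviateBM25Retriever(
    document_store=document_store,
    filters={"field": "meta.lang", "operator": "==", "value": "en"},  # illustrative metadata field
    filter_policy=FilterPolicy.REPLACE,  # the default; MERGE combines init and runtime filters
)

# With REPLACE, this runtime filter takes the place of the init-time one.
docs = retriever.run(
    query="reciprocal rank fusion",
    filters={"field": "meta.year", "operator": ">=", "value": 2023},
    top_k=5,
)["documents"]
```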
""" - filters = filters or self._filters + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + top_k = top_k or self._top_k documents = self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k) return {"documents": documents} diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py index cdf578fee..8688b4145 100644 --- a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py +++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py @@ -1,6 +1,12 @@ -from typing import Any, Dict, List, Optional +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore @@ -18,6 +24,7 @@ def __init__( top_k: int = 10, distance: Optional[float] = None, certainty: Optional[float] = None, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Creates a new instance of WeaviateEmbeddingRetriever. @@ -32,6 +39,8 @@ def __init__( The maximum allowed distance between Documents' embeddings. :param certainty: Normalized distance between the result item and the search vector. + :param filter_policy: + Policy to determine how filters are applied. :raises ValueError: If both `distance` and `certainty` are provided. See https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables to learn more about @@ -46,6 +55,9 @@ def __init__( self._top_k = top_k self._distance = distance self._certainty = certainty + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) def to_dict(self) -> Dict[str, Any]: """ @@ -60,6 +72,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, distance=self._distance, certainty=self._certainty, + filter_policy=self._filter_policy.value, document_store=self._document_store.to_dict(), ) @@ -76,6 +89,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever": data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -92,8 +111,9 @@ def run( :param query_embedding: Embedding of the query. - :param filters: - Filters to use when running the retriever. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. :param top_k: The maximum number of documents to return. 
:param distance: @@ -105,7 +125,7 @@ def run( See https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables to learn more about `distance` and `certainty` parameters. """ - filters = filters or self._filters + filters = apply_filter_policy(self._filter_policy, self._filters, filters) top_k = top_k or self._top_k distance = distance or self._distance diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py index a2201f0a5..803274aa4 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict from dateutil import parser diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index 4c3898130..33bc30159 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from enum import Enum diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index ec66e07c3..82088dd89 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -12,6 +12,7 @@ from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types.policy import DuplicatePolicy +from haystack.utils.filters import convert import weaviate from weaviate.collections.classes.data import DataObject @@ -139,47 +140,83 @@ def __init__( :param grpc_secure: Whether to use a secure channel for the underlying gRPC API. 
""" - # proxies, timeout_config, trust_env are part of additional_config now - # startup_period has been removed - self._client = weaviate.WeaviateClient( - connection_params=( - weaviate.connect.base.ConnectionParams.from_url(url=url, grpc_port=grpc_port, grpc_secure=grpc_secure) - if url - else None - ), - auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, - additional_config=additional_config, - additional_headers=additional_headers, - embedded_options=embedded_options, - skip_init_checks=False, + self._url = url + self._auth_client_secret = auth_client_secret + self._additional_headers = additional_headers + self._embedded_options = embedded_options + self._additional_config = additional_config + self._grpc_port = grpc_port + self._grpc_secure = grpc_secure + self._client = None + self._collection = None + # Store the connection settings dictionary + self._collection_settings = collection_settings or { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": DOCUMENT_COLLECTION_PROPERTIES, + } + self._clean_connection_settings() + + def _clean_connection_settings(self): + # Set the class if not set + _class_name = self._collection_settings.get("class", "Default") + _class_name = _class_name[0].upper() + _class_name[1:] + self._collection_settings["class"] = _class_name + # Set the properties if they're not set + self._collection_settings["properties"] = self._collection_settings.get( + "properties", DOCUMENT_COLLECTION_PROPERTIES ) + + @property + def client(self): + if self._client: + return self._client + + if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"): + # We use this utility function instead of using WeaviateClient directly like in other cases + # otherwise we'd have to parse the URL to get some information about the connection. + # This utility function does all that for us. + self._client = weaviate.connect_to_wcs( + self._url, + auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, + headers=self._additional_headers, + additional_config=self._additional_config, + ) + else: + # proxies, timeout_config, trust_env are part of additional_config now + # startup_period has been removed + self._client = weaviate.WeaviateClient( + connection_params=( + weaviate.connect.base.ConnectionParams.from_url( + url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure + ) + if self._url + else None + ), + auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, + additional_config=self._additional_config, + additional_headers=self._additional_headers, + embedded_options=self._embedded_options, + skip_init_checks=False, + ) + self._client.connect() # Test connection, it will raise an exception if it fails. 
- self._client.collections._get_all(simple=True) + self._client.collections.list_all(simple=True) + if not self._client.collections.exists(self._collection_settings["class"]): + self._client.collections.create_from_dict(self._collection_settings) - if collection_settings is None: - collection_settings = { - "class": "Default", - "invertedIndexConfig": {"indexNullState": True}, - "properties": DOCUMENT_COLLECTION_PROPERTIES, - } - else: - # Set the class if not set - collection_settings["class"] = collection_settings.get("class", "default").capitalize() - # Set the properties if they're not set - collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) + return self._client - if not self._client.collections.exists(collection_settings["class"]): - self._client.collections.create_from_dict(collection_settings) + @property + def collection(self): + if self._collection: + return self._collection - self._url = url - self._collection_settings = collection_settings - self._auth_client_secret = auth_client_secret - self._additional_headers = additional_headers - self._embedded_options = embedded_options - self._additional_config = additional_config - self._collection = self._client.collections.get(collection_settings["class"]) + client = self.client + self._collection = client.collections.get(self._collection_settings["class"]) + return self._collection def to_dict(self) -> Dict[str, Any]: """ @@ -228,7 +265,7 @@ def count_documents(self) -> int: """ Returns the number of documents present in the DocumentStore. """ - total = self._collection.aggregate.over_all(total_count=True).total_count + total = self.collection.aggregate.over_all(total_count=True).total_count return total if total else 0 def _to_data_object(self, document: Document) -> Dict[str, Any]: @@ -300,16 +337,16 @@ def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document: return Document.from_dict(document_data) def _query(self) -> List[Dict[str, Any]]: - properties = [p.name for p in self._collection.config.get().properties] + properties = [p.name for p in self.collection.config.get().properties] try: - result = self._collection.iterator(include_vector=True, return_properties=properties) + result = self.collection.iterator(include_vector=True, return_properties=properties) except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to query documents in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e return result def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: - properties = [p.name for p in self._collection.config.get().properties] + properties = [p.name for p in self.collection.config.get().properties] # When querying with filters we need to paginate using limit and offset as using # a cursor with after is not possible. 
See the official docs: # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after @@ -325,7 +362,7 @@ def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: # Keep querying until we get all documents matching the filters while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT: try: - partial_result = self._collection.query.fetch_objects( + partial_result = self.collection.query.fetch_objects( filters=convert_filters(filters), include_vector=True, limit=DEFAULT_QUERY_LIMIT, @@ -349,6 +386,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: The filters to apply to the document list. :returns: A list of Documents that match the given filters. """ + if filters and "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + result = [] if filters: result = self._query_with_filters(filters) @@ -363,7 +403,7 @@ def _batch_write(self, documents: List[Document]) -> int: Raises in case of errors. """ - with self._client.batch.dynamic() as batch: + with self.client.batch.dynamic() as batch: for doc in documents: if not isinstance(doc, Document): msg = f"Expected a Document, got '{type(doc)}' instead." @@ -371,11 +411,11 @@ def _batch_write(self, documents: List[Document]) -> int: batch.add_object( properties=self._to_data_object(doc), - collection=self._collection.name, + collection=self.collection.name, uuid=generate_uuid5(doc.id), vector=doc.embedding, ) - if failed_objects := self._client.batch.failed_objects: + if failed_objects := self.client.batch.failed_objects: # We fallback to use the UUID if the _original_id is not present, this is just to be mapped_objects = {} for obj in failed_objects: @@ -411,12 +451,12 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: msg = f"Expected a Document, got '{type(doc)}' instead." raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=generate_uuid5(doc.id)): + if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, we skip it continue try: - self._collection.data.insert( + self.collection.data.insert( uuid=generate_uuid5(doc.id), properties=self._to_data_object(doc), vector=doc.embedding, @@ -452,13 +492,13 @@ def delete_documents(self, document_ids: List[str]) -> None: :param document_ids: The object_ids to delete. 
""" weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids] - self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) + self.collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) def _bm25_retrieval( self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ) -> List[Document]: - properties = [p.name for p in self._collection.config.get().properties] - result = self._collection.query.bm25( + properties = [p.name for p in self.collection.config.get().properties] + result = self.collection.query.bm25( query=query, filters=convert_filters(filters) if filters else None, limit=top_k, @@ -482,8 +522,8 @@ def _embedding_retrieval( msg = "Can't use 'distance' and 'certainty' parameters together" raise ValueError(msg) - properties = [p.name for p in self._collection.config.get().properties] - result = self._collection.query.near_vector( + properties = [p.name for p in self.collection.config.get().properties] + result = self.collection.query.near_vector( near_vector=query_embedding, distance=distance, certainty=certainty, diff --git a/integrations/weaviate/tests/conftest.py b/integrations/weaviate/tests/conftest.py index ed1002409..c08ebbd38 100644 --- a/integrations/weaviate/tests/conftest.py +++ b/integrations/weaviate/tests/conftest.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from pathlib import Path import pytest diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py index b653d9105..3ad75e206 100644 --- a/integrations/weaviate/tests/test_auth.py +++ b/integrations/weaviate/tests/test_auth.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from haystack_integrations.document_stores.weaviate.auth import ( AuthApiKey, AuthBearerToken, diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py index 23b7c8f92..3720daa85 100644 --- a/integrations/weaviate/tests/test_bm25_retriever.py +++ b/integrations/weaviate/tests/test_bm25_retriever.py @@ -1,5 +1,11 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from unittest.mock import Mock, patch +import pytest +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore @@ -10,6 +16,13 @@ def test_init_default(): assert retriever._document_store == mock_document_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = WeaviateBM25Retriever(document_store=mock_document_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + WeaviateBM25Retriever(document_store=mock_document_store, filter_policy="keep_all") @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") @@ -21,6 +34,7 @@ def test_to_dict(_mock_weaviate): "init_parameters": { "filters": {}, "top_k": 10, + "filter_policy": "replace", "document_store": { "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { @@ -49,6 +63,45 @@ def test_to_dict(_mock_weaviate): 
@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_from_dict(_mock_weaviate): + retriever = WeaviateBM25Retriever.from_dict( + { + "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "additional_headers": None, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + ) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + + +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_from_dict_no_filter_policy(_mock_weaviate): retriever = WeaviateBM25Retriever.from_dict( { "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever", @@ -83,6 +136,7 @@ def test_from_dict(_mock_weaviate): assert retriever._document_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE @patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore") diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index cc76923f6..068212686 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,4 +1,9 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import base64 +import os import random from typing import List from unittest.mock import MagicMock, patch @@ -16,6 +21,7 @@ FilterDocumentsTest, WriteDocumentsTest, ) +from haystack.utils.auth import Secret from haystack_integrations.document_stores.weaviate.auth import AuthApiKey from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, @@ -26,8 +32,6 @@ from numpy import float32 as np_float32 from pandas import DataFrame from weaviate.collections.classes.data import DataObject - -# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout from weaviate.embedded import ( DEFAULT_BINARY_PATH, @@ -38,6 +42,12 @@ ) +@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") +def test_init_is_lazy(_mock_client): + _ = WeaviateDocumentStore() + _mock_client.assert_not_called() + + @pytest.mark.integration class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): @pytest.fixture @@ -57,7 +67,7 @@ def document_store(self, request) -> WeaviateDocumentStore: collection_settings=collection_settings, ) yield store - store._client.collections.delete(collection_settings["class"]) + store.client.collections.delete(collection_settings["class"]) @pytest.fixture def 
filterable_docs(self) -> List[Document]: @@ -150,12 +160,12 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert received_meta.get(key) == expected_meta.get(key) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") - def test_init(self, mock_weaviate_client_class, monkeypatch): + def test_connection(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() mock_client.collections.exists.return_value = False mock_weaviate_client_class.return_value = mock_client monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") - WeaviateDocumentStore( + ds = WeaviateDocumentStore( collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey(), additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -170,8 +180,11 @@ def test_init(self, mock_weaviate_client_class, monkeypatch): ), ) - # Verify client is created with correct parameters + # Trigger the actual database connection by accessing the `client` property so we + # can assert the setup was good + _ = ds.client + # Verify client is created with correct parameters mock_weaviate_client_class.assert_called_once_with( auth_client_secret=AuthApiKey().resolve_value(), connection_params=None, @@ -646,6 +659,15 @@ def test_embedding_retrieval_with_distance_and_certainty(self, document_store): with pytest.raises(ValueError): document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1) + def test_filter_documents_with_legacy_filters(self, document_store): + docs = [] + for index in range(10): + docs.append(Document(content="This is some content", meta={"index": index})) + document_store.write_documents(docs) + result = document_store.filter_documents({"content": {"$eq": "This is some content"}}) + + assert len(result) == 10 + def test_filter_documents_below_default_limit(self, document_store): docs = [] for index in range(9998): @@ -664,3 +686,39 @@ def test_filter_documents_over_default_limit(self, document_store): document_store.write_documents(docs) with pytest.raises(DocumentStoreError): document_store.filter_documents({"field": "content", "operator": "==", "value": "This is some content"}) + + def test_schema_class_name_conversion_preserves_pascal_case(self): + collection_settings = {"class": "CaseDocument"} + doc_score = WeaviateDocumentStore( + url="http://localhost:8080", + collection_settings=collection_settings, + ) + assert doc_score._collection_settings["class"] == "CaseDocument" + + collection_settings = {"class": "lower_case_name"} + doc_score = WeaviateDocumentStore( + url="http://localhost:8080", + collection_settings=collection_settings, + ) + assert doc_score._collection_settings["class"] == "Lower_case_name" + + @pytest.mark.skipif( + not os.environ.get("WEAVIATE_API_KEY", None) and not os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL", None), + reason="Both WEAVIATE_API_KEY and WEAVIATE_CLOUD_CLUSTER_URL are not set. 
Skipping test.", + ) + def test_connect_to_weaviate_cloud(self): + document_store = WeaviateDocumentStore( + url=os.environ.get("WEAVIATE_CLOUD_CLUSTER_URL"), + auth_client_secret=AuthApiKey(api_key=Secret.from_env_var("WEAVIATE_API_KEY")), + ) + assert document_store.client + + def test_connect_to_local(self): + document_store = WeaviateDocumentStore( + url="http://localhost:8080", + ) + assert document_store.client + + def test_connect_to_embedded(self): + document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions()) + assert document_store.client diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py index c7c147ba5..13f214dd1 100644 --- a/integrations/weaviate/tests/test_embedding_retriever.py +++ b/integrations/weaviate/tests/test_embedding_retriever.py @@ -1,6 +1,11 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from unittest.mock import Mock, patch import pytest +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore @@ -11,9 +16,16 @@ def test_init_default(): assert retriever._document_store == mock_document_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._distance is None assert retriever._certainty is None + retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + WeaviateEmbeddingRetriever(document_store=mock_document_store, filter_policy="keep_all") + def test_init_with_distance_and_certainty(): mock_document_store = Mock(spec=WeaviateDocumentStore) @@ -30,6 +42,7 @@ def test_to_dict(_mock_weaviate): "init_parameters": { "filters": {}, "top_k": 10, + "filter_policy": "replace", "distance": None, "certainty": None, "document_store": { @@ -60,6 +73,49 @@ def test_to_dict(_mock_weaviate): @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_from_dict(_mock_weaviate): + retriever = WeaviateEmbeddingRetriever.from_dict( + { + "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "filter_policy": "replace", + "distance": None, + "certainty": None, + "document_store": { + "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": None, + "collection_settings": { + "class": "Default", + "invertedIndexConfig": {"indexNullState": True}, + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, + "auth_client_secret": None, + "additional_headers": None, + "embedded_options": None, + "additional_config": None, + }, + }, + }, + } + ) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._distance is None + assert retriever._certainty is None + + 
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") +def test_from_dict_no_filter_policy(_mock_weaviate): retriever = WeaviateEmbeddingRetriever.from_dict( { "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever", # noqa: E501 @@ -98,6 +154,7 @@ def test_from_dict(_mock_weaviate): assert retriever._top_k == 10 assert retriever._distance is None assert retriever._certainty is None + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE @patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore") diff --git a/integrations/weaviate/tests/test_filters.py b/integrations/weaviate/tests/test_filters.py index c32d69e2f..26997a05c 100644 --- a/integrations/weaviate/tests/test_filters.py +++ b/integrations/weaviate/tests/test_filters.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from haystack_integrations.document_stores.weaviate._filters import _invert_condition diff --git a/show_unreleased.sh b/show_unreleased.sh new file mode 100755 index 000000000..17f18f22c --- /dev/null +++ b/show_unreleased.sh @@ -0,0 +1,9 @@ +#!/bin/bash +INTEGRATION=$1 +if [ -z "${INTEGRATION}" ] ; then + echo "Please provide the name of an integration, for example:" + echo "./$(basename $0) chroma" + exit 1 +fi +LATEST_TAG=$(git tag -l --sort=-creatordate "integrations/${INTEGRATION}-v*" | head -n 1) +git --no-pager diff $LATEST_TAG..main integrations/${INTEGRATION}