diff --git a/.github/utils/pyproject_to_requirements.py b/.github/utils/pyproject_to_requirements.py new file mode 100644 index 000000000..48f07ffdb --- /dev/null +++ b/.github/utils/pyproject_to_requirements.py @@ -0,0 +1,26 @@ +import argparse +import sys +from pathlib import Path +import toml + +def main(pyproject_path: Path, exclude_optional_dependencies: bool = False): + content = toml.load(pyproject_path) + deps = set(content["project"]["dependencies"]) + + if not exclude_optional_dependencies: + optional_deps = content["project"].get("optional-dependencies", {}) + for dep_list in optional_deps.values(): + deps.update(dep_list) + + print("\n".join(sorted(deps))) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="pyproject_to_requirements.py", + description="Convert pyproject.toml to requirements.txt" + ) + parser.add_argument("pyproject_path", type=Path, help="Path to pyproject.toml file") + parser.add_argument("--exclude-optional-dependencies", action="store_true", help="Exclude optional dependencies") + + args = parser.parse_args() + main(args.pyproject_path, args.exclude_optional_dependencies) diff --git a/.github/workflows/CI_license_compliance.yml b/.github/workflows/CI_license_compliance.yml new file mode 100644 index 000000000..fc28706df --- /dev/null +++ b/.github/workflows/CI_license_compliance.yml @@ -0,0 +1,95 @@ +name: Core / License Compliance + +on: + pull_request: + paths: + - "integrations/**/pyproject.toml" + # Since we test PRs, there is no need to run the workflow at each + # merge on `main`. Let's use a cron job instead. + schedule: + - cron: "0 0 * * *" # every day at midnight + +env: + CORE_DATADOG_API_KEY: ${{ secrets.CORE_DATADOG_API_KEY }} + PYTHON_VERSION: "3.10" + EXCLUDE_PACKAGES: "(?i)^(deepeval|cohere|fastembed|ragas|tqdm|psycopg).*" + + # Exclusions must be explicitly motivated + # + # - deepeval is Apache 2.0 but the license is not available on PyPI + # - cohere is MIT but the license is not available on PyPI + # - fastembed is Apache 2.0 but the license on PyPI is unclear ("Other/Proprietary License (Apache License)") + # - ragas is Apache 2.0 but the license is not available on PyPI + + # - tqdm is MPL but there are no better alternatives + # - psycopg is LGPL-3.0 but FOSSA is fine with it + +jobs: + license_check_direct: + name: Direct dependencies only + env: + REQUIREMENTS_FILE: requirements_direct.txt + runs-on: ubuntu-latest + steps: + - name: Checkout the code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "${{ env.PYTHON_VERSION }}" + + - name: Get changed pyproject files (for pull requests only) + if: ${{ github.event_name == 'pull_request' }} + id: changed-files + uses: tj-actions/changed-files@v45 + with: + files: | + integrations/**/pyproject.toml + + - name: Get direct dependencies from pyproject.toml files + run: | + pip install toml + + # Determine the list of pyproject.toml files to process + if [ "${{ github.event_name }}" = "schedule" ]; then + echo "Scheduled run: processing all pyproject.toml files..." + FILES=$(find integrations -type f -name 'pyproject.toml') + else + echo "Pull request: processing changed pyproject.toml files..."
+ FILES="${{ steps.changed-files.outputs.all_changed_files }}" + fi + + for file in $FILES; do + python .github/utils/pyproject_to_requirements.py $file >> ${{ env.REQUIREMENTS_FILE }} + echo "" >> ${{ env.REQUIREMENTS_FILE }} + done + + - name: Check Licenses + id: license_check_report + uses: pilosus/action-pip-license-checker@v2 + with: + github-token: ${{ secrets.GH_ACCESS_TOKEN }} + requirements: ${{ env.REQUIREMENTS_FILE }} + fail: "Copyleft,Other,Error" + exclude: "${{ env.EXCLUDE_PACKAGES }}" + + # We keep the license inventory on FOSSA + - name: Send license report to Fossa + uses: fossas/fossa-action@v1.4.0 + continue-on-error: true # not critical + with: + api-key: ${{ secrets.FOSSA_LICENSE_SCAN_TOKEN }} + + - name: Print report + if: ${{ always() }} + run: echo "${{ steps.license_check_report.outputs.report }}" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations license compliance nightly failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} + \ No newline at end of file diff --git a/README.md b/README.md index 010ca1763..c4178184b 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta ## Inventory +[![License Compliance](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/CI_license_compliance.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/CI_license_compliance.yml) + | Package | Type | PyPi Package | Status | |----------------------------------------------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [amazon-bedrock-haystack](integrations/amazon_bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index cbdd97046..0fc7ce0ab 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -2,11 +2,24 @@ ## [unreleased] +### 🐛 Bug Fixes + +- Remove the use of deprecated gemini models (#1032) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + ### ⚙️ Miscellaneous Tasks - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +### Docs + +- Update GeminiGenerator docstrings (#964) +- Update GoogleChatGenerator docstrings (#962) + ## [integrations/google_ai-v1.1.0] - 2024-06-05 ### 🐛 Bug Fixes diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 
dd065af4b..cf0005f39 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -1,16 +1,16 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai from google.ai.generativelanguage import Content, Part from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool +from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory, Tool from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses import ByteStream, StreamingChunk from haystack.dataclasses.chat_message import ChatMessage, ChatRole -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) @@ -21,10 +21,7 @@ class GoogleAIGeminiChatGenerator: Completes chats using multimodal Gemini models through Google AI Studio. It uses the [`ChatMessage`](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) - dataclass to interact with the model. You can use the following models: - - gemini-pro - - gemini-ultra - - gemini-pro-vision + dataclass to interact with the model. ### Usage example @@ -103,27 +100,20 @@ def __init__( self, *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008 - model: str = "gemini-pro-vision", + model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Initializes a `GoogleAIGeminiChatGenerator` instance. To get an API key, visit: https://makersuite.google.com - It supports the following models: - * `gemini-pro` - * `gemini-pro-vision` - * `gemini-ultra` - :param api_key: Google AI Studio API key. To get a key, see [Google AI Studio](https://makersuite.google.com). - :param model: Name of the model to use. Supported models are: - - gemini-pro - - gemini-ultra - - gemini-pro-vision + :param model: Name of the model to use. For available models, see https://ai.google.dev/gemini-api/docs/models/gemini. :param generation_config: The generation configuration to use. This can either be a `GenerationConfig` object or a dictionary of parameters. For available parameters, see @@ -132,6 +122,8 @@ def __init__( A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. 
""" genai.configure(api_key=api_key.resolve_value()) @@ -142,6 +134,7 @@ def __init__( self._safety_settings = safety_settings self._tools = tools self._model = GenerativeModel(self._model_name, tools=self._tools) + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -162,6 +155,8 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None + data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -169,6 +164,7 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [] @@ -213,6 +209,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": data["init_parameters"]["safety_settings"] = { HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items() } + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -274,16 +272,23 @@ def _message_to_content(self, message: ChatMessage) -> Content: return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage]): + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates text based on the provided messages. :param messages: A list of `ChatMessage` instances, representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - `replies`: A list containing the generated responses as `ChatMessage` instances. """ + streaming_callback = streaming_callback or self._streaming_callback history = [self._message_to_content(m) for m in messages[:-1]] session = self._model.start_chat(history=history) @@ -292,10 +297,22 @@ def run(self, messages: List[ChatMessage]): content=new_message, generation_config=self._generation_config, safety_settings=self._safety_settings, + stream=streaming_callback is not None, ) + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]: + """ + Extracts the responses from the Google AI response. + + :param response_body: The response from Google AI request. + :returns: The extracted responses. 
+ """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": replies.append(ChatMessage.from_system(part.text)) @@ -307,5 +324,23 @@ def run(self, messages: List[ChatMessage]): name=part.function_call.name, ) ) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + """ + Extracts the responses from the Google AI streaming response. + + :param stream: The streaming response from the Google AI request. + :param streaming_callback: The handler for the streaming response. + :returns: The extracted response with the content of all streaming chunks. + """ + responses = [] + for chunk in stream: + content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" + streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + responses.append(content) + + combined_response = "".join(responses).lstrip() + return [ChatMessage.from_system(content=combined_response)] diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 07277e55a..218e16c4c 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -1,15 +1,15 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai from google.ai.generativelanguage import Content, Part, Tool from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from haystack.core.component import component from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.dataclasses import ByteStream, StreamingChunk +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class GoogleAIGeminiGenerator: for url in URLS ] - gemini = GoogleAIGeminiGenerator(model="gemini-pro-vision", api_key=Secret.from_token("")) + gemini = GoogleAIGeminiGenerator(model="gemini-1.5-flash", api_key=Secret.from_token("")) result = gemini.run(parts = ["What can you tell me about this robots?", *images]) for answer in result["replies"]: print(answer) @@ -66,23 +66,19 @@ def __init__( self, *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), # noqa: B008 - model: str = "gemini-pro-vision", + model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Initializes a `GoogleAIGeminiGenerator` instance. 
To get an API key, visit: https://makersuite.google.com - It supports the following models: - * `gemini-pro` - * `gemini-pro-vision` - * `gemini-ultra` - :param api_key: Google AI Studio API key. - :param model: Name of the model to use. + :param model: Name of the model to use. For available models, see https://ai.google.dev/gemini-api/docs/models/gemini :param generation_config: The generation configuration to use. This can either be a `GenerationConfig` object or a dictionary of parameters. For available parameters, see @@ -91,6 +87,8 @@ def __init__( A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ genai.configure(api_key=api_key.resolve_value()) @@ -100,6 +98,7 @@ def __init__( self._safety_settings = safety_settings self._tools = tools self._model = GenerativeModel(self._model_name, tools=self._tools) + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -120,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None data = default_to_dict( self, api_key=self._api_key.to_dict(), @@ -127,6 +127,7 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] @@ -156,6 +157,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": data["init_parameters"]["safety_settings"] = { HarmCategory(k): HarmBlockThreshold(v) for k, v in safety_settings.items() } + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -176,28 +179,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(replies=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + def run( + self, + parts: Variadic[Union[str, ByteStream, Part]], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates text based on the given input parts. :param parts: A heterogeneous list of strings, `ByteStream` or `Part` objects. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - `replies`: A list of strings or dictionaries with function calls. 
""" + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback converted_parts = [self._convert_part(p) for p in parts] - contents = [Content(parts=converted_parts, role="user")] res = self._model.generate_content( contents=contents, generation_config=self._generation_config, safety_settings=self._safety_settings, + stream=streaming_callback is not None, ) self._model.start_chat() + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerateContentResponse) -> List[str]: + """ + Extracts the responses from the Google AI request. + :param response_body: The response body from the Google AI request. + :returns: A list of string responses. + """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": replies.append(part.text) @@ -207,5 +227,23 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): "args": dict(part.function_call.args.items()), } replies.append(function_call) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[str]: + """ + Extracts the responses from the Google AI streaming response. + :param stream: The streaming response from the Google AI request. + :param streaming_callback: The handler for the streaming response. + :returns: A list of string responses. + """ + + responses = [] + for chunk in stream: + content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" + streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) + responses.append(content) + + combined_response = ["".join(responses).lstrip()] + return combined_response diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 9b3124eab..1a910b977 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -4,10 +4,30 @@ import pytest from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool +from haystack.dataclasses import StreamingChunk from haystack.dataclasses.chat_message import ChatMessage from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -21,26 +41,7 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) with patch( "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure" ) as mock_genai_configure: @@ -50,7 +51,7 @@ def test_init(monkeypatch): tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") - assert gemini._model_name == "gemini-pro-vision" + assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] @@ -60,6 +61,24 @@ def test_init(monkeypatch): def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator() + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +def test_to_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -69,26 +88,7 @@ def test_to_dict(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): gemini = GoogleAIGeminiChatGenerator( @@ -100,7 +100,7 @@ def test_to_dict(monkeypatch): "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", "init_parameters": { "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-pro-vision", + "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -110,6 +110,7 @@ def test_to_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -128,7 +129,32 @@ def test_from_dict(monkeypatch): "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", "init_parameters": { "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-pro-vision", + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config is None + assert gemini._safety_settings is None + assert gemini._tools is None + assert isinstance(gemini._model, GenerativeModel) + + +def test_from_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -138,6 +164,7 @@ def test_from_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -147,7 +174,7 @@ def test_from_dict(monkeypatch): } ) - assert gemini._model_name == "gemini-pro-vision" + assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -198,6 +225,33 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert len(res["replies"]) > 0 +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") +def test_run_with_streaming_callback(): + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 + return {"weather": "sunny", "temperature": 21.8, "unit": unit} + + get_current_weather_func = 
FunctionDeclaration.from_function( + get_current_weather, + descriptions={ + "location": "The city and state, e.g. San Francisco, CA", + "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", + }, + ) + + tool = Tool(function_declarations=[get_current_weather_func]) + gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) + messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + res = gemini_chat.run(messages=messages) + assert len(res["replies"]) > 0 + assert streaming_callback_called + + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 35c7d196b..7206b7a43 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -5,9 +5,29 @@ from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -48,7 +68,7 @@ def test_init(monkeypatch): tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") - assert gemini._model_name == "gemini-pro-vision" + assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] @@ -58,6 +78,24 @@ def test_init(monkeypatch): def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): + gemini = GoogleAIGeminiGenerator() + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +def test_to_dict_with_param(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -97,7 +135,7 @@ def test_to_dict(monkeypatch): assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", "init_parameters": { - "model": "gemini-pro-vision", + "model": "gemini-1.5-flash", "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, "generation_config": { "temperature": 0.5, @@ -108,6 +146,7 @@ def test_to_dict(monkeypatch): 
"stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -117,7 +156,7 @@ def test_to_dict(monkeypatch): } -def test_from_dict(monkeypatch): +def test_from_dict_with_param(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): @@ -125,7 +164,7 @@ def test_from_dict(monkeypatch): { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", "init_parameters": { - "model": "gemini-pro-vision", + "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -135,6 +174,7 @@ def test_from_dict(monkeypatch): "stop_sequences": ["stop"], }, "safety_settings": {10: 3}, + "streaming_callback": None, "tools": [ b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" @@ -144,7 +184,7 @@ def test_from_dict(monkeypatch): } ) - assert gemini._model_name == "gemini-pro-vision" + assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -154,33 +194,49 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [ - Tool( - function_declarations=[ - FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], + assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + assert isinstance(gemini._model, GenerativeModel) + + +def test_from_dict(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): + gemini = GoogleAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], }, - ) - ] + "safety_settings": {10: 3}, + "streaming_callback": None, + "tools": [ + b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" + b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" + b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" + ], + }, + } ) - ] + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -189,3 +245,17 @@ def test_run(): gemini = GoogleAIGeminiGenerator(model="gemini-pro") res = gemini.run("Tell me something cool") assert len(res["replies"]) > 0 + + +@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") +def test_run_with_streaming_callback(): + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + gemini = GoogleAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) + res = gemini.run("Tell me something cool") + assert len(res["replies"]) > 0 + assert streaming_callback_called diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index f08a69b5f..893710121 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -1,15 +1,17 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union -import vertexai from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses import StreamingChunk from haystack.dataclasses.byte_stream import ByteStream from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.utils import deserialize_callable, serialize_callable +from vertexai import init as vertexai_init from vertexai.preview.generative_models import ( Content, - FunctionDeclaration, GenerationConfig, + GenerationResponse, GenerativeModel, HarmBlockThreshold, HarmCategory, @@ -25,9 +27,6 @@ class VertexAIGeminiChatGenerator: """ `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models. - `VertexAIGeminiChatGenerator` supports both `gemini-pro` and `gemini-pro-vision` models. - Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`. - Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). @@ -49,12 +48,13 @@ class VertexAIGeminiChatGenerator: def __init__( self, *, - model: str = "gemini-pro", + model: str = "gemini-1.5-flash", project_id: str, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models. 
@@ -63,7 +63,7 @@ def __init__( For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. - :param model: Name of the model to use, defaults to "gemini-pro-vision". + :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. :param generation_config: Configuration for the generation process. @@ -76,10 +76,13 @@ def __init__( :param tools: List of tools to use when generating content. See the documentation for [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) the list of supported arguments. + :param streaming_callback: A callback function that is called when a new token is received from + the stream. The callback function accepts StreamingChunk as an argument. + """ # Login to GCP. This will fail if user has not set up their gcloud SDK - vertexai.init(project=project_id, location=location) + vertexai_init(project=project_id, location=location) self._model_name = model self._project_id = project_id @@ -89,18 +92,7 @@ def __init__( self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools - - def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: - return { - "name": function._raw_function_declaration.name, - "parameters": function._raw_function_declaration.parameters, - "description": function._raw_function_declaration.description, - } - - def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: - return { - "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], - } + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -121,6 +113,8 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None + data = default_to_dict( self, model=self._model_name, @@ -129,9 +123,10 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -150,7 +145,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator": data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -195,13 +191,21 @@ def _message_to_content(self, message: ChatMessage) -> Content: return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage]): + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """Prompts Google Vertex AI Gemini model to generate a response to a list of messages. :param messages: The last message is the prompt, the rest are the history. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary with the following keys: - `replies`: A list of ChatMessage objects representing the model's replies. """ + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback + history = [self._message_to_content(m) for m in messages[:-1]] session = self._model.start_chat(history=history) @@ -211,10 +215,22 @@ def run(self, messages: List[ChatMessage]): generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + stream=streaming_callback is not None, ) + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: + """ + Extracts the responses from the Vertex AI response. + + :param response_body: The response from Vertex AI request. + :returns: The extracted responses. 
+ """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(ChatMessage.from_system(part.text)) @@ -226,5 +242,23 @@ def run(self, messages: List[ChatMessage]): name=part.function_call.name, ) ) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + """ + Extracts the responses from the Vertex AI streaming response. + + :param stream: The streaming response from the Vertex AI request. + :param streaming_callback: The handler for the streaming response. + :returns: The extracted response with the content of all streaming chunks. + """ + responses = [] + for chunk in stream: + streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) + streaming_callback(streaming_chunk) + responses.append(streaming_chunk.content) + + combined_response = "".join(responses).lstrip() + return [ChatMessage.from_system(content=combined_response)] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 8a288a315..11592671f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -1,15 +1,16 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union -import vertexai from haystack.core.component import component from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream -from vertexai.preview.generative_models import ( +from haystack.dataclasses import ByteStream, StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable +from vertexai import init as vertexai_init +from vertexai.generative_models import ( Content, - FunctionDeclaration, GenerationConfig, + GenerationResponse, GenerativeModel, HarmBlockThreshold, HarmCategory, @@ -25,9 +26,6 @@ class VertexAIGeminiGenerator: """ `VertexAIGeminiGenerator` enables text generation using Google Gemini models. - `VertexAIGeminiGenerator` supports both `gemini-pro` and `gemini-pro-vision` models. - Prompting with images requires `gemini-pro-vision`. Function calling, instead, requires `gemini-pro`. - Usage example: ```python from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator @@ -54,12 +52,13 @@ class VertexAIGeminiGenerator: def __init__( self, *, - model: str = "gemini-pro-vision", + model: str = "gemini-1.5-flash", project_id: str, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Multi-modal generator using Gemini model via Google Vertex AI. @@ -68,7 +67,7 @@ def __init__( For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). :param project_id: ID of the GCP project to use. 
- :param model: Name of the model to use. + :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. Can either be a [`GenerationConfig`](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig) @@ -87,10 +86,12 @@ def __init__( :param tools: List of tools to use when generating content. See the documentation for [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) the list of supported arguments. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ # Login to GCP. This will fail if user has not set up their gcloud SDK - vertexai.init(project=project_id, location=location) + vertexai_init(project=project_id, location=location) self._model_name = model self._project_id = project_id @@ -100,18 +101,7 @@ def __init__( self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools - - def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: - return { - "name": function._raw_function_declaration.name, - "parameters": function._raw_function_declaration.parameters, - "description": function._raw_function_declaration.description, - } - - def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: - return { - "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], - } + self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -132,6 +122,8 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" + + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None data = default_to_dict( self, model=self._model_name, @@ -140,9 +132,10 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -161,7 +154,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -176,14 +170,21 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(replies=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + def run( + self, + parts: Variadic[Union[str, ByteStream, Part]], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates content using the Gemini model. :param parts: Prompt for the model. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary with the following keys: - `replies`: A list of generated content. """ + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] @@ -192,10 +193,23 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + stream=streaming_callback is not None, ) self._model.start_chat() + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerationResponse) -> List[str]: + """ + Extracts the responses from the Vertex AI response. + + :param response_body: The response body from the Vertex AI request. + + :returns: A list of string responses. 
+ """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) @@ -205,5 +219,24 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): "args": dict(part.function_call.args.items()), } replies.append(function_call) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] + ) -> List[str]: + """ + Extracts the responses from the Vertex AI streaming response. + + :param stream: The streaming response from the Vertex AI request. + :param streaming_callback: The handler for the streaming response. + :returns: A list of string responses. + """ + streaming_chunks: List[StreamingChunk] = [] + + for chunk in stream: + streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) + + responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()] + return responses diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py new file mode 100644 index 000000000..8d08e0859 --- /dev/null +++ b/integrations/google_vertex/tests/test_gemini.py @@ -0,0 +1,256 @@ +from unittest.mock import MagicMock, Mock, patch + +from haystack.dataclasses import StreamingChunk +from vertexai.preview.generative_models import ( + FunctionDeclaration, + GenerationConfig, + HarmBlockThreshold, + HarmCategory, + Tool, +) + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator + +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_init(mock_vertexai_init, _mock_generative_model): + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + mock_vertexai_init.assert_called() + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == [tool] + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict(_mock_vertexai_init, _mock_generative_model): + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._generation_config is None + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + assert isinstance(gemini._generation_config, GenerationConfig) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run(mock_generative_model): + mock_model = Mock() + mock_model.generate_content.return_value = MagicMock() + mock_generative_model.return_value = mock_model + + gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + + response = gemini.run(["What's the weather like today?"]) + + mock_model.generate_content.assert_called_once() + assert "replies" in response + assert isinstance(response["replies"], list) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run_with_streaming_callback(mock_generative_model): + mock_model = Mock() + mock_stream = [ + MagicMock(text="First part", usage_metadata={}), + MagicMock(text="Second part", usage_metadata={}), + ] + + mock_model.generate_content.return_value = mock_stream + mock_generative_model.return_value = mock_model + + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini.run(["Come on, stream!"]) + assert streaming_callback_called diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index a00c913d2..f66536fe5 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -2,15 +2,36 @@ ## [unreleased] +### ๐Ÿš€ Features + +- Update default embedding model to nvidia/nv-embedqa-e5-v5 (#1015) +- Add NVIDIA NIM ranker support (#1023) + +### ๐Ÿ› Bug Fixes + +- Lints in `nvidia-haystack` (#993) + ### ๐Ÿšœ Refactor - Remove deprecated Nvidia Cloud Functions backend and related code. 
(#803) +### ๐Ÿ“š Documentation + +- Update Nvidia API docs (#1031) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + ### โš™๏ธ Miscellaneous Tasks - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +### Docs + +- Update NvidiaGenerator docstrings (#966) + ## [integrations/nvidia-v0.0.3] - 2024-05-22 ### ๐Ÿ“š Documentation diff --git a/integrations/nvidia/pydoc/config.yml b/integrations/nvidia/pydoc/config.yml index 80bb212c5..a7ab228cf 100644 --- a/integrations/nvidia/pydoc/config.yml +++ b/integrations/nvidia/pydoc/config.yml @@ -5,7 +5,10 @@ loaders: [ "haystack_integrations.components.embedders.nvidia.document_embedder", "haystack_integrations.components.embedders.nvidia.text_embedder", + "haystack_integrations.components.embedders.nvidia.truncate", "haystack_integrations.components.generators.nvidia.generator", + "haystack_integrations.components.rankers.nvidia.ranker", + "haystack_integrations.components.rankers.nvidia.truncate", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 52c5393c4..3e911e4f4 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -93,7 +93,7 @@ def __init__( self._initialized = False if is_hosted(api_url) and not self.model: # manually set default model - self.model = "NV-Embed-QA" + self.model = "nvidia/nv-embedqa-e5-v5" def default_model(self): """Set default model in local NIM mode.""" diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 946fc08cb..0387c32b7 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -77,7 +77,7 @@ def __init__( self._initialized = False if is_hosted(api_url) and not self.model: # manually set default model - self.model = "NV-Embed-QA" + self.model = "nvidia/nv-embedqa-e5-v5" def default_model(self): """Set default model in local NIM mode.""" diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py new file mode 100644 index 000000000..29cb2f7f5 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py @@ -0,0 +1,3 @@ +from .ranker import NvidiaRanker + +__all__ = ["NvidiaRanker"] diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py new file mode 100644 index 000000000..46c736883 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -0,0 +1,206 @@ +import warnings +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace + +from haystack_integrations.utils.nvidia import NimBackend, url_validation + +from .truncate 
import RankerTruncateMode + +_DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" + +_MODEL_ENDPOINT_MAP = { + "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", +} + + +@component +class NvidiaRanker: + """ + A component for ranking documents using ranking models provided by + [NVIDIA NIMs](https://ai.nvidia.com). + + Usage example: + ```python + from haystack_integrations.components.rankers.nvidia import NvidiaRanker + from haystack import Document + from haystack.utils import Secret + + ranker = NvidiaRanker( + model="nvidia/nv-rerankqa-mistral-4b-v3", + api_key=Secret.from_env_var("NVIDIA_API_KEY"), + ) + ranker.warm_up() + + query = "What is the capital of Germany?" + documents = [ + Document(content="Berlin is the capital of Germany."), + Document(content="The capital of Germany is Berlin."), + Document(content="Germany's capital is Berlin."), + ] + + result = ranker.run(query, documents, top_k=2) + print(result["documents"]) + ``` + """ + + def __init__( + self, + model: Optional[str] = None, + truncate: Optional[Union[RankerTruncateMode, str]] = None, + api_url: Optional[str] = None, + api_key: Optional[Secret] = None, + top_k: int = 5, + ): + """ + Create a NvidiaRanker component. + + :param model: + Ranking model to use. + :param truncate: + Truncation strategy to use. Can be "NONE", "END", or RankerTruncateMode. Defaults to NIM's default. + :param api_key: + API key for the NVIDIA NIM. + :param api_url: + Custom API URL for the NVIDIA NIM. + :param top_k: + Number of documents to return. + """ + if model is not None and not isinstance(model, str): + msg = "Ranker expects the `model` parameter to be a string." + raise TypeError(msg) + if not isinstance(api_url, (str, type(None))): + msg = "Ranker expects the `api_url` parameter to be a string." + raise TypeError(msg) + if truncate is not None and not isinstance(truncate, RankerTruncateMode): + truncate = RankerTruncateMode.from_str(truncate) + if not isinstance(top_k, int): + msg = "Ranker expects the `top_k` parameter to be an integer." + raise TypeError(msg) + + # todo: detect default in non-hosted case (when api_url is provided) + self._model = model or _DEFAULT_MODEL + self._truncate = truncate + self._api_key = api_key + # if no api_url is provided, we're using a hosted model and can + # - assume the default url will work, because there's only one model + # - assume we won't call backend.models() + if api_url is not None: + self._api_url = url_validation(api_url, None, ["v1/ranking"]) + self._endpoint = None # we let backend.rank() handle the endpoint + else: + if self._model not in _MODEL_ENDPOINT_MAP: + msg = f"Model '{model}' is unknown. Please provide an api_url to access it." + raise ValueError(msg) + self._api_url = None # we handle the endpoint + self._endpoint = _MODEL_ENDPOINT_MAP[self._model] + if api_key is None: + self._api_key = Secret.from_env_var("NVIDIA_API_KEY") + self._top_k = top_k + self._initialized = False + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the ranker to a dictionary. + + :returns: A dictionary containing the ranker's attributes. + """ + return default_to_dict( + self, + model=self._model, + top_k=self._top_k, + truncate=self._truncate, + api_url=self._api_url, + api_key=self._api_key, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": + """ + Deserialize the ranker from a dictionary. + + :param data: A dictionary containing the ranker's attributes. 
+ :returns: The deserialized ranker. + """ + deserialize_secrets_inplace(data, keys=["api_key"]) + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initialize the ranker. + + :raises ValueError: If the API key is required for hosted NVIDIA NIMs. + """ + if not self._initialized: + model_kwargs = {} + if self._truncate is not None: + model_kwargs.update(truncate=str(self._truncate)) + self._backend = NimBackend( + self._model, + api_url=self._api_url, + api_key=self._api_key, + model_kwargs=model_kwargs, + ) + if not self._model: + self._model = _DEFAULT_MODEL + self._initialized = True + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + ) -> Dict[str, List[Document]]: + """ + Rank a list of documents based on a given query. + + :param query: The query to rank the documents against. + :param documents: The list of documents to rank. + :param top_k: The number of documents to return. + + :raises RuntimeError: If the ranker has not been loaded. + :raises TypeError: If the arguments are of the wrong type. + + :returns: A dictionary containing the ranked documents. + """ + if not self._initialized: + msg = "The ranker has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) + if not isinstance(query, str): + msg = "Ranker expects the `query` parameter to be a string." + raise TypeError(msg) + if not isinstance(documents, list): + msg = "Ranker expects the `documents` parameter to be a list." + raise TypeError(msg) + if not all(isinstance(doc, Document) for doc in documents): + msg = "Ranker expects the `documents` parameter to be a list of Document objects." + raise TypeError(msg) + if top_k is not None and not isinstance(top_k, int): + msg = "Ranker expects the `top_k` parameter to be an integer." + raise TypeError(msg) + + if len(documents) == 0: + return {"documents": []} + + top_k = top_k if top_k is not None else self._top_k + if top_k < 1: + warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) + return {"documents": []} + + assert self._backend is not None + # rank result is list[{index: int, logit: float}] sorted by logit + sorted_indexes_and_scores = self._backend.rank( + query, + documents, + endpoint=self._endpoint, + ) + sorted_documents = [] + for item in sorted_indexes_and_scores[:top_k]: + # mutate (don't copy) the document because we're only updating the score + doc = documents[item["index"]] + doc.score = item["logit"] + sorted_documents.append(doc) + + return {"documents": sorted_documents} diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py new file mode 100644 index 000000000..3b5d7f40a --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py @@ -0,0 +1,27 @@ +from enum import Enum + + +class RankerTruncateMode(str, Enum): + """ + Specifies how inputs to the NVIDIA ranker components are truncated. + If NONE, the input will not be truncated and an error returned instead. + If END, the input will be truncated from the end. + """ + + NONE = "NONE" + END = "END" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "RankerTruncateMode": + """ + Create an truncate mode from a string. + + :param string: + String to convert. + :returns: + Truncate mode. 
+ """ + return cls(string) diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index f69862f0e..0d1f57e5c 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple import requests +from haystack import Document from haystack.utils import Secret REQUEST_TIMEOUT = 60 @@ -129,3 +130,28 @@ def models(self) -> List[Model]: msg = f"No hosted model were found at URL '{url}'." raise ValueError(msg) return models + + def rank( + self, + query: str, + documents: List[Document], + endpoint: Optional[str] = None, + ) -> List[Dict[str, Any]]: + url = endpoint or f"{self.api_url}/ranking" + + res = self.session.post( + url, + json={ + "model": self.model, + "query": {"text": query}, + "passages": [{"text": doc.content} for doc in documents], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + + data = res.json() + assert "rankings" in data, f"Expected 'rankings' in response, got {data}" + + return data["rankings"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py index 794c994ff..a6c78ba4e 100644 --- a/integrations/nvidia/tests/conftest.py +++ b/integrations/nvidia/tests/conftest.py @@ -2,9 +2,10 @@ import pytest from haystack.utils import Secret -from haystack_integrations.utils.nvidia import Model, NimBackend from requests_mock import Mocker +from haystack_integrations.utils.nvidia import Model, NimBackend + class MockBackend(NimBackend): def __init__(self, model: str, api_key: Optional[Secret] = None, model_kwargs: Optional[Dict[str, Any]] = None): diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 072807685..426bacc25 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,6 +1,8 @@ import pytest + from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder from haystack_integrations.components.generators.nvidia import NvidiaGenerator +from haystack_integrations.components.rankers.nvidia import NvidiaRanker @pytest.mark.parametrize( @@ -14,12 +16,12 @@ ], ) @pytest.mark.parametrize( - "embedder", - [NvidiaDocumentEmbedder, NvidiaTextEmbedder], + "component", + [NvidiaDocumentEmbedder, NvidiaTextEmbedder, NvidiaRanker], ) -def test_base_url_invalid_not_hosted(base_url: str, embedder) -> None: +def test_base_url_invalid_not_hosted(base_url: str, component) -> None: with pytest.raises(ValueError): - embedder(api_url=base_url, model="x") + component(api_url=base_url, model="x") @pytest.mark.parametrize( @@ -62,3 +64,12 @@ def test_base_url_valid_generator(base_url: str) -> None: def test_base_url_invalid_generator(base_url: str) -> None: with pytest.raises(ValueError): NvidiaGenerator(api_url=base_url, model="x") + + +@pytest.mark.parametrize( + "base_url", + ["http://localhost:8080/v1/ranking", "http://0.0.0.0:8888/v1/ranking"], +) +def test_base_url_valid_ranker(base_url: str) -> None: + with pytest.warns(UserWarning): + NvidiaRanker(api_url=base_url) diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 6562a0ea9..bef0f996e 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ 
b/integrations/nvidia/tests/test_document_embedder.py @@ -3,6 +3,7 @@ import pytest from haystack import Document from haystack.utils import Secret + from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaDocumentEmbedder from . import MockBackend @@ -14,7 +15,7 @@ def test_init_default(self, monkeypatch): embedder = NvidiaDocumentEmbedder() assert embedder.api_key == Secret.from_env_var("NVIDIA_API_KEY") - assert embedder.model == "NV-Embed-QA" + assert embedder.model == "nvidia/nv-embedqa-e5-v5" assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" assert embedder.prefix == "" assert embedder.suffix == "" @@ -372,15 +373,34 @@ def test_run_integration_with_nim_backend(self): assert isinstance(doc.embedding, list) assert isinstance(doc.embedding[0], float) + @pytest.mark.parametrize( + "model, api_url", + [ + ("NV-Embed-QA", None), + ("snowflake/arctic-embed-l", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embed-v1", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embedqa-mistral-7b-v2", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embedqa-e5-v5", "https://integrate.api.nvidia.com/v1"), + ("baai/bge-m3", "https://integrate.api.nvidia.com/v1"), + ], + ids=[ + "NV-Embed-QA", + "snowflake/arctic-embed-l", + "nvidia/nv-embed-v1", + "nvidia/nv-embedqa-mistral-7b-v2", + "nvidia/nv-embedqa-e5-v5", + "baai/bge-m3", + ], + ) @pytest.mark.skipif( not os.environ.get("NVIDIA_API_KEY", None), reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", ) @pytest.mark.integration - def test_run_integration_with_api_catalog(self): + def test_run_integration_with_api_catalog(self, model, api_url): embedder = NvidiaDocumentEmbedder( - model="NV-Embed-QA", - api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia", + model=model, + **({"api_url": api_url} if api_url else {}), api_key=Secret.from_env_var("NVIDIA_API_KEY"), ) embedder.warm_up() diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 9fff9c2e8..0bd8b1fc6 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -5,9 +5,10 @@ import pytest from haystack.utils import Secret -from haystack_integrations.components.generators.nvidia import NvidiaGenerator from requests_mock import Mocker +from haystack_integrations.components.generators.nvidia import NvidiaGenerator + @pytest.fixture def mock_local_chat_completion(requests_mock: Mocker) -> None: diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py new file mode 100644 index 000000000..566fd18a8 --- /dev/null +++ b/integrations/nvidia/tests/test_ranker.py @@ -0,0 +1,258 @@ +import os +import re +from typing import Any, Optional, Union + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.components.rankers.nvidia import NvidiaRanker +from haystack_integrations.components.rankers.nvidia.ranker import _DEFAULT_MODEL +from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode + + +class TestNvidiaRanker: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker() + assert client._model == _DEFAULT_MODEL + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + + def test_init_with_parameters(self): + client = NvidiaRanker( + api_key=Secret.from_token("fake-api-key"), + 
model=_DEFAULT_MODEL, + top_k=3, + truncate="END", + ) + assert client._api_key == Secret.from_token("fake-api-key") + assert client._model == _DEFAULT_MODEL + assert client._top_k == 3 + assert client._truncate == RankerTruncateMode.END + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("NVIDIA_API_KEY", raising=False) + client = NvidiaRanker() + with pytest.raises(ValueError): + client.warm_up() + + def test_init_pass_wo_api_key_w_api_url(self): + url = "https://url.bogus/v1" + client = NvidiaRanker(api_url=url) + assert client._api_url == url + + def test_warm_up_required(self): + client = NvidiaRanker() + with pytest.raises(RuntimeError) as e: + client.run("query", [Document(content="doc")]) + assert "not been loaded" in str(e.value) + + @pytest.mark.parametrize( + "truncate", + [ + None, + "END", + "NONE", + RankerTruncateMode.END, + RankerTruncateMode.NONE, + ], + ids=["None", "END-str", "NONE-str", "END-enum", "NONE-enum"], + ) + def test_mocked( + self, + requests_mock, + monkeypatch, + truncate: Optional[Union[RankerTruncateMode, str]], + ) -> None: + query = "What is it?" + documents = [ + Document(content="Nothing really."), + Document(content="Maybe something."), + Document(content="Not this."), + ] + + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + + requests_mock.post( + re.compile(r".*ranking"), + json={ + "rankings": [ + {"index": 1, "logit": 4.2}, + {"index": 0, "logit": 2.4}, + {"index": 2, "logit": -4.2}, + ] + }, + ) + + truncate_param = {} + if truncate: + truncate_param = {"truncate": truncate} + client = NvidiaRanker( + top_k=2, + **truncate_param, + ) + client.warm_up() + + response = client.run( + query=query, + documents=documents, + )["documents"] + + assert requests_mock.last_request is not None + request_payload = requests_mock.last_request.json() + if truncate is None: + assert "truncate" not in request_payload + else: + assert "truncate" in request_payload + assert request_payload["truncate"] == str(truncate) + + assert len(response) == 2 + assert response[0].content == documents[1].content + assert response[0].score == 4.2 + assert response[1].content == documents[0].content + assert response[1].score == 2.4 + + response = client.run( + query=query, + documents=documents, + top_k=1, + )["documents"] + assert len(response) == 1 + assert response[0].content == documents[1].content + assert response[0].score == 4.2 + + @pytest.mark.parametrize("truncate", [True, False, 1, 0, 1.0, "START", "BOGUS"]) + def test_truncate_invalid(self, truncate: Any) -> None: + with pytest.raises(ValueError) as e: + NvidiaRanker(truncate=truncate) + assert "not a valid RankerTruncateMode" in str(e.value) + + @pytest.mark.parametrize("top_k", [1.0, "BOGUS"]) + def test_top_k_invalid(self, monkeypatch, top_k: Any) -> None: + with pytest.raises(TypeError) as e: + NvidiaRanker(top_k=top_k) + assert "parameter to be an integer" in str(e.value) + + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker() + client.warm_up() + with pytest.raises(TypeError) as e: + client.run("query", [Document(content="doc")], top_k=top_k) + assert "parameter to be an integer" in str(e.value) + + @pytest.mark.skipif( + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", + ) + @pytest.mark.integration + def test_integration( + self, + ) -> None: + query = "What is it?" 
+ documents = [ + Document(content="Nothing really."), + Document(content="Maybe something."), + Document(content="Not this."), + ] + + client = NvidiaRanker(top_k=2) + client.warm_up() + + response = client.run(query=query, documents=documents)["documents"] + + assert len(response) == 2 + assert {response[0].content, response[1].content} == {documents[0].content, documents[1].content} + + @pytest.mark.skipif( + not os.environ.get("NVIDIA_NIM_RANKER_MODEL", None) + or not os.environ.get("NVIDIA_NIM_RANKER_ENDPOINT_URL", None), + reason="Export an env var called NVIDIA_NIM_RANKER_MODEL containing the hosted model name and " + "NVIDIA_NIM_RANKER_ENDPOINT_URL containing the local URL to call.", + ) + @pytest.mark.integration + def test_nim_integration(self): + query = "What is it?" + documents = [ + Document(content="Nothing really."), + Document(content="Maybe something."), + Document(content="Not this."), + ] + + client = NvidiaRanker( + model=os.environ["NVIDIA_NIM_RANKER_MODEL"], + api_url=os.environ["NVIDIA_NIM_RANKER_ENDPOINT_URL"], + top_k=2, + ) + client.warm_up() + + response = client.run(query=query, documents=documents)["documents"] + + assert len(response) == 2 + assert {response[0].content, response[1].content} == {documents[0].content, documents[1].content} + + def test_top_k_warn(self, monkeypatch) -> None: + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + + client = NvidiaRanker(top_k=0) + client.warm_up() + with pytest.warns(UserWarning) as record0: + client.run("query", [Document(content="doc")]) + assert "top_k should be at least 1" in str(record0[0].message) + + client = NvidiaRanker(top_k=1) + client.warm_up() + with pytest.warns(UserWarning) as record1: + client.run("query", [Document(content="doc")], top_k=0) + assert "top_k should be at least 1" in str(record1[0].message) + + def test_model_typeerror(self) -> None: + with pytest.raises(TypeError) as e: + NvidiaRanker(model=1) + assert "parameter to be a string" in str(e.value) + + def test_api_url_typeerror(self) -> None: + with pytest.raises(TypeError) as e: + NvidiaRanker(api_url=1) + assert "parameter to be a string" in str(e.value) + + def test_query_typeerror(self, monkeypatch) -> None: + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker() + client.warm_up() + with pytest.raises(TypeError) as e: + client.run(1, [Document(content="doc")]) + assert "parameter to be a string" in str(e.value) + + def test_documents_typeerror(self, monkeypatch) -> None: + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker() + client.warm_up() + with pytest.raises(TypeError) as e: + client.run("query", "doc") + assert "parameter to be a list" in str(e.value) + + with pytest.raises(TypeError) as e: + client.run("query", [1]) + assert "list of Document objects" in str(e.value) + + def test_top_k_typeerror(self, monkeypatch) -> None: + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker() + client.warm_up() + with pytest.raises(TypeError) as e: + client.run("query", [Document(content="doc")], top_k="1") + assert "parameter to be an integer" in str(e.value) + + def test_model_unknown(self) -> None: + with pytest.raises(ValueError) as e: + NvidiaRanker(model="unknown-model") + assert "unknown" in str(e.value) + + def test_warm_up_once(self, monkeypatch) -> None: + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + + client = NvidiaRanker() + client.warm_up() + backend = client._backend + client.warm_up() + assert backend == client._backend diff --git 
a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 7c0a7000d..7c8428cc2 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -2,6 +2,7 @@ import pytest from haystack.utils import Secret + from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode, NvidiaTextEmbedder from . import MockBackend @@ -169,15 +170,34 @@ def test_run_integration_with_nim_backend(self): assert all(isinstance(x, float) for x in embedding) assert "usage" in meta + @pytest.mark.parametrize( + "model, api_url", + [ + ("NV-Embed-QA", None), + ("snowflake/arctic-embed-l", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embed-v1", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embedqa-mistral-7b-v2", "https://integrate.api.nvidia.com/v1"), + ("nvidia/nv-embedqa-e5-v5", "https://integrate.api.nvidia.com/v1"), + ("baai/bge-m3", "https://integrate.api.nvidia.com/v1"), + ], + ids=[ + "NV-Embed-QA", + "snowflake/arctic-embed-l", + "nvidia/nv-embed-v1", + "nvidia/nv-embedqa-mistral-7b-v2", + "nvidia/nv-embedqa-e5-v5", + "baai/bge-m3", + ], + ) @pytest.mark.skipif( not os.environ.get("NVIDIA_API_KEY", None), reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", ) @pytest.mark.integration - def test_run_integration_with_api_catalog(self): + def test_run_integration_with_api_catalog(self, model, api_url): embedder = NvidiaTextEmbedder( - model="NV-Embed-QA", - api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia", + model=model, + **({"api_url": api_url} if api_url else {}), api_key=Secret.from_env_var("NVIDIA_API_KEY"), ) embedder.warm_up() diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index f5d45d2b8..4a8478e2c 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -73,8 +73,16 @@ def __init__( An example `run()` method for this `custom_query`: ```python - retriever.run(query="Why did the revenue increase?", - filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + retriever.run( + query="Why did the revenue increase?", + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.years", "operator": "==", "value": "2019"}, + {"field": "meta.quarters", "operator": "in", "value": ["Q1", "Q2"]}, + ], + }, + ) ``` :param raise_on_failure: Whether to raise an exception if the API call fails. Otherwise log a warning and return an empty list. 
@@ -184,8 +192,16 @@ def run( **For this custom_query, a sample `run()` could be:** ```python - retriever.run(query="Why did the revenue increase?", - filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + retriever.run( + query="Why did the revenue increase?", + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.years", "operator": "==", "value": "2019"}, + {"field": "meta.quarters", "operator": "in", "value": ["Q1", "Q2"]}, + ], + }, + ) ``` :returns: diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 1c1071a76..e159634cf 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -71,8 +71,16 @@ def __init__( For this `custom_query`, an example `run()` could be: ```python - retriever.run(query_embedding=embedding, - filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + retriever.run( + query_embedding=embedding, + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.years", "operator": "==", "value": "2019"}, + {"field": "meta.quarters", "operator": "in", "value": ["Q1", "Q2"]}, + ], + }, + ) ``` :param raise_on_failure: If `True`, raises an exception if the API call fails. @@ -176,8 +184,16 @@ def run( For this `custom_query`, an example `run()` could be: ```python - retriever.run(query_embedding=embedding, - filters={"years": ["2019"], "quarters": ["Q1", "Q2"]}) + retriever.run( + query_embedding=embedding, + filters={ + "operator": "AND", + "conditions": [ + {"field": "meta.years", "operator": "==", "value": "2019"}, + {"field": "meta.quarters", "operator": "in", "value": ["Q1", "Q2"]}, + ], + }, + ) ``` :returns: diff --git a/integrations/pinecone/CHANGELOG.md b/integrations/pinecone/CHANGELOG.md index d9d12505e..a041d63de 100644 --- a/integrations/pinecone/CHANGELOG.md +++ b/integrations/pinecone/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## [unreleased] +## [integrations/pinecone-v1.2.3] - 2024-08-29 ### ๐Ÿš€ Features @@ -10,6 +10,7 @@ - `pinecone` - Fallback to default filter policy when deserializing retrievers without the init parameter (#901) - Skip unsupported meta fields in PineconeDB (#1009) +- Converting `Pinecone` metadata fields from float back to int (#1034) ### ๐Ÿงช Testing diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index 27eba6ecf..75d6270ca 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -267,6 +267,20 @@ def _embedding_retrieval( return self._convert_query_result_to_documents(result) + @staticmethod + def _convert_meta_to_int(metadata: Dict[str, Any]) -> Dict[str, Any]: + """ + Pinecone store numeric metadata values as `float`. Some specific metadata are used in Retrievers components and + are expected to be `int`. This method converts them back to integers. 
+ """ + values_to_convert = ["split_id", "split_idx_start", "page_number"] + + for value in values_to_convert: + if value in metadata: + metadata[value] = int(metadata[value]) if isinstance(metadata[value], float) else metadata[value] + + return metadata + def _convert_query_result_to_documents(self, query_result: Dict[str, Any]) -> List[Document]: pinecone_docs = query_result["matches"] documents = [] @@ -278,8 +292,7 @@ def _convert_query_result_to_documents(self, query_result: Dict[str, Any]) -> Li if dataframe_string: dataframe = pd.read_json(io.StringIO(dataframe_string)) - # we always store vectors during writing - # but we don't want to return them if they are dummy vectors + # we always store vectors during writing but we don't want to return them if they are dummy vectors embedding = None if pinecone_doc["values"] != self._dummy_vector: embedding = pinecone_doc["values"] @@ -288,7 +301,7 @@ def _convert_query_result_to_documents(self, query_result: Dict[str, Any]) -> Li id=pinecone_doc["id"], content=content, dataframe=dataframe, - meta=pinecone_doc["metadata"], + meta=self._convert_meta_to_int(pinecone_doc["metadata"]), embedding=embedding, score=pinecone_doc["score"], ) diff --git a/integrations/pinecone/tests/test_document_store.py b/integrations/pinecone/tests/test_document_store.py index bd443b4a8..dcecf7996 100644 --- a/integrations/pinecone/tests/test_document_store.py +++ b/integrations/pinecone/tests/test_document_store.py @@ -5,10 +5,13 @@ import numpy as np import pytest from haystack import Document +from haystack.components.preprocessors import DocumentSplitter +from haystack.components.retrievers import SentenceWindowRetriever from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest from haystack.utils import Secret from pinecone import Pinecone, PodSpec, ServerlessSpec +from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever from haystack_integrations.document_stores.pinecone import PineconeDocumentStore @@ -178,6 +181,36 @@ def test_discard_invalid_meta_valid(): assert pinecone_doc.meta["page_number"] == 1 +def test_convert_meta_to_int(): + # Test with floats + meta_data = {"split_id": 1.0, "split_idx_start": 2.0, "page_number": 3.0} + assert PineconeDocumentStore._convert_meta_to_int(meta_data) == { + "split_id": 1, + "split_idx_start": 2, + "page_number": 3, + } + + # Test with floats and ints + meta_data = {"split_id": 1.0, "split_idx_start": 2, "page_number": 3.0} + assert PineconeDocumentStore._convert_meta_to_int(meta_data) == { + "split_id": 1, + "split_idx_start": 2, + "page_number": 3, + } + + # Test with floats and strings + meta_data = {"split_id": 1.0, "other": "other_data", "page_number": 3.0} + assert PineconeDocumentStore._convert_meta_to_int(meta_data) == { + "split_id": 1, + "other": "other_data", + "page_number": 3, + } + + # Test with empty dict + meta_data = {} + assert PineconeDocumentStore._convert_meta_to_int(meta_data) == {} + + @pytest.mark.integration @pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") def test_serverless_index_creation_from_scratch(sleep_time): @@ -257,3 +290,28 @@ def test_embedding_retrieval(self, document_store: PineconeDocumentStore): assert len(results) == 2 assert results[0].content == "Most similar document" assert results[1].content == "2nd best document" + + def test_sentence_window_retriever(self, document_store: PineconeDocumentStore): + # indexing + splitter = 
DocumentSplitter(split_length=10, split_overlap=5, split_by="word") + text = ( + "Whose woods these are I think I know. His house is in the village though; He will not see me stopping " + "here To watch his woods fill up with snow." + ) + docs = splitter.run(documents=[Document(content=text)]) + + for idx, doc in enumerate(docs["documents"]): + if idx == 2: + doc.embedding = [0.1] * 768 + continue + doc.embedding = np.random.rand(768).tolist() + document_store.write_documents(docs["documents"]) + + # query + embedding_retriever = PineconeEmbeddingRetriever(document_store=document_store) + query_embedding = [0.1] * 768 + retrieved_doc = embedding_retriever.run(query_embedding=query_embedding, top_k=1, filters={}) + sentence_window_retriever = SentenceWindowRetriever(document_store=document_store, window_size=2) + result = sentence_window_retriever.run(retrieved_documents=[retrieved_doc["documents"][0]]) + + assert len(result["context_windows"]) == 1 diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md index d17f549da..521f2b0c0 100644 --- a/integrations/qdrant/CHANGELOG.md +++ b/integrations/qdrant/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/qdrant-v4.2.0] - 2024-08-27 + +### ๐Ÿšœ Refactor + +- Qdrant Query API (#1025) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + ## [integrations/qdrant-v4.1.2] - 2024-07-15 ### ๐Ÿ› Bug Fixes diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index d55cbd71c..0612373fb 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -15,7 +15,6 @@ from qdrant_client import grpc from qdrant_client.http import models as rest from qdrant_client.http.exceptions import UnexpectedResponse -from qdrant_client.hybrid.fusion import reciprocal_rank_fusion from tqdm import tqdm from .converters import ( @@ -537,20 +536,18 @@ def _query_by_sparse( qdrant_filters = convert_filters_to_qdrant(filters) query_indices = query_sparse_embedding.indices query_values = query_sparse_embedding.values - points = self.client.search( + points = self.client.query_points( collection_name=self.index, - query_vector=rest.NamedSparseVector( - name=SPARSE_VECTORS_NAME, - vector=rest.SparseVector( - indices=query_indices, - values=query_values, - ), + query=rest.SparseVector( + indices=query_indices, + values=query_values, ), + using=SPARSE_VECTORS_NAME, query_filter=qdrant_filters, limit=top_k, with_vectors=return_embedding, score_threshold=score_threshold, - ) + ).points results = [ convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) for point in points @@ -588,17 +585,15 @@ def _query_by_embedding( """ qdrant_filters = convert_filters_to_qdrant(filters) - points = self.client.search( + points = self.client.query_points( collection_name=self.index, - query_vector=rest.NamedVector( - name=DENSE_VECTORS_NAME if self.use_sparse_embeddings else "", - vector=query_embedding, - ), + query=query_embedding, + using=DENSE_VECTORS_NAME if self.use_sparse_embeddings else None, query_filter=qdrant_filters, limit=top_k, with_vectors=return_embedding, score_threshold=score_threshold, - ) + ).points results = [ convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) for 
point in points @@ -655,46 +650,34 @@ def _query_hybrid( qdrant_filters = convert_filters_to_qdrant(filters) - sparse_request = rest.SearchRequest( - vector=rest.NamedSparseVector( - name=SPARSE_VECTORS_NAME, - vector=rest.SparseVector( - indices=query_sparse_embedding.indices, - values=query_sparse_embedding.values, - ), - ), - filter=qdrant_filters, - limit=top_k, - with_payload=True, - with_vector=return_embedding, - score_threshold=score_threshold, - ) - - dense_request = rest.SearchRequest( - vector=rest.NamedVector( - name=DENSE_VECTORS_NAME, - vector=query_embedding, - ), - filter=qdrant_filters, - limit=top_k, - with_payload=True, - with_vector=return_embedding, - ) - try: - dense_request_response, sparse_request_response = self.client.search_batch( - collection_name=self.index, requests=[dense_request, sparse_request] - ) + points = self.client.query_points( + collection_name=self.index, + prefetch=[ + rest.Prefetch( + query=rest.SparseVector( + indices=query_sparse_embedding.indices, + values=query_sparse_embedding.values, + ), + using=SPARSE_VECTORS_NAME, + filter=qdrant_filters, + ), + rest.Prefetch( + query=query_embedding, + using=DENSE_VECTORS_NAME, + filter=qdrant_filters, + ), + ], + query=rest.FusionQuery(fusion=rest.Fusion.RRF), + limit=top_k, + score_threshold=score_threshold, + with_payload=True, + with_vectors=return_embedding, + ).points except Exception as e: msg = "Error during hybrid search" raise QdrantStoreError(msg) from e - try: - points = reciprocal_rank_fusion(responses=[dense_request_response, sparse_request_response], limit=top_k) - except Exception as e: - msg = "Error while applying Reciprocal Rank Fusion" - raise QdrantStoreError(msg) from e - results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points] return results diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index 5559d7bac..112b7e5ac 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -114,19 +114,7 @@ def test_query_hybrid_search_batch_failure(self): sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 - with patch.object(document_store.client, "search_batch", side_effect=Exception("search_batch error")): + with patch.object(document_store.client, "query_points", side_effect=Exception("query_points")): with pytest.raises(QdrantStoreError): document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding) - - @patch("haystack_integrations.document_stores.qdrant.document_store.reciprocal_rank_fusion") - def test_query_hybrid_reciprocal_rank_fusion_failure(self, mocked_fusion): - document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) - - sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) - embedding = [0.1] * 768 - - mocked_fusion.side_effect = Exception("reciprocal_rank_fusion error") - - with pytest.raises(QdrantStoreError): - document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding)
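
For reference, the `_query_hybrid` refactor above replaces the two `rest.SearchRequest` objects sent through `search_batch` plus the client-side `reciprocal_rank_fusion` call with a single Query API request in which Qdrant fuses the dense and sparse sub-queries server-side. The following minimal sketch (not part of the diff) shows the same `query_points` shape against a standalone `qdrant-client`; the collection name `"demo"` and the vector names `"text-dense"`/`"text-sparse"` are placeholder assumptions standing in for the store's `DENSE_VECTORS_NAME`/`SPARSE_VECTORS_NAME` constants.

```python
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

# Local, in-memory client; assumes a qdrant-client version with Query API support (>= 1.10).
client = QdrantClient(location=":memory:")
client.create_collection(
    collection_name="demo",  # placeholder collection for the sketch
    vectors_config={"text-dense": rest.VectorParams(size=768, distance=rest.Distance.COSINE)},
    sparse_vectors_config={"text-sparse": rest.SparseVectorParams()},
)

dense_query = [0.1] * 768
sparse_query = rest.SparseVector(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])

points = client.query_points(
    collection_name="demo",
    prefetch=[
        # Each prefetch runs one sub-query; Qdrant fuses the two result lists server-side.
        rest.Prefetch(query=sparse_query, using="text-sparse"),
        rest.Prefetch(query=dense_query, using="text-dense"),
    ],
    query=rest.FusionQuery(fusion=rest.Fusion.RRF),  # replaces the manual reciprocal_rank_fusion step
    limit=5,
    with_payload=True,
).points
print(points)  # empty list here, since nothing was upserted
```

Because rank fusion now happens inside Qdrant, the separate client-side RRF failure path is no longer reachable, which is why the `test_query_hybrid_reciprocal_rank_fusion_failure` test is removed above.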