diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md index 3067b0a5e..3f36836cc 100644 --- a/integrations/cohere/CHANGELOG.md +++ b/integrations/cohere/CHANGELOG.md @@ -1,15 +1,30 @@ # Changelog -## [unreleased] +## [integrations/cohere-v2.0.0] - 2024-09-16 ### ๐Ÿš€ Features - Update Anthropic/Cohere for tools use (#790) - Update Cohere default LLMs, add examples and update unit tests (#838) +- Cohere LLM - adjust token counting meta to match OpenAI format (#1086) + +### ๐Ÿ› Bug Fixes + +- Lints in `cohere-haystack` (#995) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) ### โš™๏ธ Miscellaneous Tasks - Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +### Docs + +- Update CohereChatGenerator docstrings (#958) +- Update CohereGenerator docstrings (#960) ## [integrations/cohere-v1.1.1] - 2024-06-12 diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 568a26979..e635e291c 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -178,7 +178,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, if finish_response.meta.billed_units: tokens_in = finish_response.meta.billed_units.input_tokens or -1 tokens_out = finish_response.meta.billed_units.output_tokens or -1 - chat_message.meta["usage"] = tokens_in + tokens_out + chat_message.meta["usage"] = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out} chat_message.meta.update( { "model": self.model, @@ -220,11 +220,13 @@ def _build_message(self, cohere_response): message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json()) elif cohere_response.text: message = ChatMessage.from_assistant(content=cohere_response.text) - total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens message.meta.update( { "model": self.model, - "usage": total_tokens, + "usage": { + "prompt_tokens": cohere_response.meta.billed_units.input_tokens, + "completion_tokens": cohere_response.meta.billed_units.output_tokens, + }, "index": 0, "finish_reason": cohere_response.finish_reason, "documents": cohere_response.documents, diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 6521503f2..fe9b7f43e 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -169,6 +169,9 @@ def test_live_run(self): assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.content + assert "usage" in message.meta + assert "prompt_tokens" in message.meta["usage"] + assert "completion_tokens" in message.meta["usage"] @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), @@ -210,6 +213,10 @@ def __call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + assert "usage" in message.meta + assert "prompt_tokens" in message.meta["usage"] + assert "completion_tokens" in message.meta["usage"] + @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 0a90a7121..ccd68ded3 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog -## [unreleased] +## [integrations/langfuse-v0.4.0] - 2024-09-17 + +### ๐Ÿš€ Features + +- Langfuse - support generation span for more LLMs (#1087) ### ๐Ÿšœ Refactor diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 443d65a13..0d9c42787 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -1,19 +1,46 @@ import os +# See README.md for more information on how to set up the environment variables +# before running this script + +# In addition to setting the environment variables, you need to install the following packages: +# pip install cohere-haystack anthropic-haystack os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.generators.chat import HuggingFaceAPIChatGenerator, OpenAIChatGenerator from haystack.dataclasses import ChatMessage +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType + from haystack_integrations.components.connectors.langfuse import LangfuseConnector +from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator +from haystack_integrations.components.generators.cohere import CohereChatGenerator + +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" + +selected_chat_generator = "openai" + +generators = { + "openai": OpenAIChatGenerator, + "anthropic": AnthropicChatGenerator, + "hf_api": lambda: HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "mistralai/Mixtral-8x7B-Instruct-v0.1"}, + token=Secret.from_token(os.environ["HF_API_KEY"]), + ), + "cohere": CohereChatGenerator, +} + +selected_chat_generator = generators[selected_chat_generator]() if __name__ == "__main__": pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector("Chat example")) pipe.add_component("prompt_builder", ChatPromptBuilder()) - pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) + pipe.add_component("llm", selected_chat_generator) pipe.connect("prompt_builder.prompt", "llm.messages") diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index d92c62668..6f9213be7 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -47,6 +47,8 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "anthropic-haystack", + "cohere-haystack" ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 7d141c08c..94064a0d1 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -10,8 +10,22 @@ import langfuse HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" -_SUPPORTED_GENERATORS = ["AzureOpenAIGenerator", "OpenAIGenerator"] -_SUPPORTED_CHAT_GENERATORS = ["AzureOpenAIChatGenerator", "OpenAIChatGenerator"] +_SUPPORTED_GENERATORS = [ + "AzureOpenAIGenerator", + "OpenAIGenerator", + "AnthropicGenerator", + "HuggingFaceAPIGenerator", + "HuggingFaceLocalGenerator", + "CohereGenerator", +] +_SUPPORTED_CHAT_GENERATORS = [ + "AzureOpenAIChatGenerator", + "OpenAIChatGenerator", + "AnthropicChatGenerator", + "HuggingFaceAPIChatGenerator", + "HuggingFaceLocalChatGenerator", + "CohereChatGenerator", +] _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 111d89dfd..4e8c679be 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -1,34 +1,38 @@ import os - -# don't remove (or move) this env var setting from here, it's needed to turn tracing on -os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" - -from urllib.parse import urlparse - import pytest +from urllib.parse import urlparse import requests - +from requests.auth import HTTPBasicAuth from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage -from requests.auth import HTTPBasicAuth - from haystack_integrations.components.connectors.langfuse import LangfuseConnector +from haystack.components.generators.chat import OpenAIChatGenerator + +from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator +from haystack_integrations.components.generators.cohere import CohereChatGenerator + +# don't remove (or move) this env var setting from here, it's needed to turn tracing on +os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" @pytest.mark.integration -@pytest.mark.skipif( - not os.environ.get("LANGFUSE_SECRET_KEY", None) and not os.environ.get("LANGFUSE_PUBLIC_KEY", None), - reason="Export an env var called LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY containing Langfuse credentials.", +@pytest.mark.parametrize( + "llm_class, env_var, expected_trace", + [ + (OpenAIChatGenerator, "OPENAI_API_KEY", "OpenAI"), + (AnthropicChatGenerator, "ANTHROPIC_API_KEY", "Anthropic"), + (CohereChatGenerator, "COHERE_API_KEY", "Cohere"), + ], ) -def test_tracing_integration(): +def test_tracing_integration(llm_class, env_var, expected_trace): + if not all([os.environ.get("LANGFUSE_SECRET_KEY"), os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get(env_var)]): + pytest.skip(f"Missing required environment variables: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, or {env_var}") pipe = Pipeline() - pipe.add_component("tracer", LangfuseConnector(name="Chat example", public=True)) # public so anyone can verify run + pipe.add_component("tracer", LangfuseConnector(name=f"Chat example - {expected_trace}", public=True)) pipe.add_component("prompt_builder", ChatPromptBuilder()) - pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) - + pipe.add_component("llm", llm_class()) pipe.connect("prompt_builder.prompt", "llm.messages") messages = [ @@ -39,17 +43,20 @@ def test_tracing_integration(): response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] + url = "https://cloud.langfuse.com/api/public/traces/" trace_url = response["tracer"]["trace_url"] - parsed_url = urlparse(trace_url) - # trace id is the last part of the path (after the last '/') - uuid = os.path.basename(parsed_url.path) + uuid = os.path.basename(urlparse(trace_url).path) + try: - # GET request with Basic Authentication on the Langfuse API response = requests.get( - url + uuid, auth=HTTPBasicAuth(os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get("LANGFUSE_SECRET_KEY")) + url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) ) - assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + + # check if the trace contains the expected LLM name + assert expected_trace in str(response.content) + # check if the trace contains the expected generation span + assert "GENERATION" in str(response.content) except requests.exceptions.RequestException as e: - assert False, f"Failed to retrieve data from Langfuse API: {e}" + pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md new file mode 100644 index 000000000..0553a3f4b --- /dev/null +++ b/integrations/snowflake/CHANGELOG.md @@ -0,0 +1 @@ +## [integrations/snowflake-v0.0.1] - 2024-09-06 \ No newline at end of file diff --git a/integrations/snowflake/README.md b/integrations/snowflake/README.md new file mode 100644 index 000000000..30f0aee1a --- /dev/null +++ b/integrations/snowflake/README.md @@ -0,0 +1,23 @@ +# snowflake-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [License](#license) + +## Installation + +```console +pip install snowflake-haystack +``` +## Examples +You can find a code example showing how to use the Retriever under the `example/` folder of this repo. + +## License + +`snowflake-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. \ No newline at end of file diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py new file mode 100644 index 000000000..b85a4c677 --- /dev/null +++ b/integrations/snowflake/example/text2sql_example.py @@ -0,0 +1,120 @@ +from dotenv import load_dotenv +from haystack import Pipeline +from haystack.components.builders import PromptBuilder +from haystack.components.converters import OutputAdapter +from haystack.components.generators import OpenAIGenerator +from haystack.utils import Secret + +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever + +load_dotenv() + +sql_template = """ + You are a SQL expert working with Snowflake. + + Your task is to create a Snowflake SQL query for the given question. + + Refrain from explaining your answer. Your answer must be the SQL query + in plain text format without using Markdown. + + Here are some relevant tables, a description about it, and their + columns: + + Database name: DEMO_DB + Schema name: ADVENTURE_WORKS + Table names: + - ADDRESS: Employees Address Table + - EMPLOYEE: Employees directory + - SALESTERRITORY: Sales territory lookup table. + - SALESORDERHEADER: General sales order information. + + User's question: {{ question }} + + Generated SQL query: +""" + +sql_builder = PromptBuilder(template=sql_template) + +analyst_template = """ + You are an expert data analyst. + + Your role is to answer the user's question {{ question }} using the information + in the table. + + You will base your response solely on the information provided in the + table. + + Do not rely on your knowledge base; only the data that is in the table. + + Refrain from using the term "table" in your response, but instead, use + the word "data" + + If the table is blank say: + + "The specific answer can't be found in the database. Try rephrasing your + question." + + Additionally, you will present the table in a tabular format and provide + the SQL query used to extract the relevant rows from the database in + Markdown. + + If the table is larger than 10 rows, display the most important rows up + to 10 rows. Your answer must be detailed and provide insights based on + the question and the available data. + + SQL query: + + {{ sql_query }} + + Table: + + {{ table }} + + Answer: +""" + +analyst_builder = PromptBuilder(template=analyst_template) + +# LLM responsible for generating the SQL query +sql_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 1000}, +) + +# LLM responsible for analyzing the table +analyst_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, +) + +snowflake = SnowflakeTableRetriever( + user="", + account="", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + warehouse="", +) + +adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) + +pipeline = Pipeline() + +pipeline.add_component(name="sql_builder", instance=sql_builder) +pipeline.add_component(name="sql_llm", instance=sql_llm) +pipeline.add_component(name="adapter", instance=adapter) +pipeline.add_component(name="snowflake", instance=snowflake) +pipeline.add_component(name="analyst_builder", instance=analyst_builder) +pipeline.add_component(name="analyst_llm", instance=analyst_llm) + + +pipeline.connect("sql_builder.prompt", "sql_llm.prompt") +pipeline.connect("sql_llm.replies", "adapter.replies") +pipeline.connect("adapter.output", "snowflake.query") +pipeline.connect("snowflake.table", "analyst_builder.table") +pipeline.connect("adapter.output", "analyst_builder.sql_query") +pipeline.connect("analyst_builder.prompt", "analyst_llm.prompt") + +question = "What are my top territories by number of orders and by sales value?" + +response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) diff --git a/integrations/snowflake/pydoc/config.yml b/integrations/snowflake/pydoc/config.yml new file mode 100644 index 000000000..7237b3816 --- /dev/null +++ b/integrations/snowflake/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: + [ + "haystack_integrations.components.retrievers.snowflake.snowflake_retriever" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Snowflake integration for Haystack + category_slug: integrations-api + title: Snowflake + slug: integrations-Snowflake + order: 130 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_snowflake.md \ No newline at end of file diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml new file mode 100644 index 000000000..68f9ec477 --- /dev/null +++ b/integrations/snowflake/pyproject.toml @@ -0,0 +1,149 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "snowflake-haystack" +dynamic = ["version"] +description = 'A Snowflake integration for the Haystack framework.' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Mohamed Sriha", email = "mohamed.sriha@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "snowflake-connector-python>=3.10.1", "tabulate>=0.9.0"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/snowflake-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"] +all = ["style", "typing"] + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Ignore SQL injection + "S608", + # Unused method argument + "ARG002" +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["snowflake_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "openai.*", "snowflake.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py new file mode 100644 index 000000000..294d3cce4 --- /dev/null +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .snowflake_table_retriever import SnowflakeTableRetriever + +__all__ = ["SnowflakeTableRetriever"] diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py new file mode 100644 index 000000000..aa6f5ff4d --- /dev/null +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import re +from typing import Any, Dict, Final, Optional, Union + +import pandas as pd +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_secrets_inplace + +with LazyImport("Run 'pip install snowflake-connector-python>=3.10.1'") as snow_import: + import snowflake.connector + from snowflake.connector.connection import SnowflakeConnection + from snowflake.connector.errors import ( + DatabaseError, + ForbiddenError, + ProgrammingError, + ) + +logger = logging.getLogger(__name__) + +MAX_SYS_ROWS: Final = 1000000 # Max rows to fetch from a table + + +@component +class SnowflakeTableRetriever: + """ + Connects to a Snowflake database to execute a SQL query. + For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector). + + ### Usage example: + + ```python + executor = SnowflakeTableRetriever( + user="", + account="", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + database="", + db_schema="", + warehouse="", + ) + + # When database and schema are provided during component initialization. + query = "SELECT * FROM table_name" + + # or + + # When database and schema are NOT provided during component initialization. + query = "SELECT * FROM database_name.schema_name.table_name" + + results = executor.run(query=query) + + print(results["dataframe"].head(2)) # Pandas dataframe + # Column 1 Column 2 + # 0 Value1 Value2 + # 1 Value1 Value2 + + print(results["table"]) # Markdown + # | Column 1 | Column 2 | + # |:----------|:----------| + # | Value1 | Value2 | + # | Value1 | Value2 | + ``` + """ + + def __init__( + self, + user: str, + account: str, + api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + database: Optional[str] = None, + db_schema: Optional[str] = None, + warehouse: Optional[str] = None, + login_timeout: Optional[int] = None, + ) -> None: + """ + :param user: User's login. + :param account: Snowflake account identifier. + :param api_key: Snowflake account password. + :param database: Name of the database to use. + :param db_schema: Name of the schema to use. + :param warehouse: Name of the warehouse to use. + :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + """ + + self.user = user + self.account = account + self.api_key = api_key + self.database = database + self.db_schema = db_schema + self.warehouse = warehouse + self.login_timeout = login_timeout or 60 + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + user=self.user, + account=self.account, + api_key=self.api_key.to_dict(), + database=self.database, + db_schema=self.db_schema, + warehouse=self.warehouse, + login_timeout=self.login_timeout, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + + @staticmethod + def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConnection, None]: + """ + Connect to a Snowflake database. + + :param connect_params: Snowflake connection parameters. + """ + try: + return snowflake.connector.connect(**connect_params) + except DatabaseError as e: + logger.error("{error_msg} ", errno=e.errno, error_msg=e.msg) + return None + + @staticmethod + def _extract_table_names(query: str) -> list: + """ + Extract table names from an SQL query using regex. + The extracted table names will be checked for privilege. + + :param query: SQL query to extract table names from. + """ + + suffix = "\\s+([a-zA-Z0-9_.]+)" # Regular expressions to match table names in various clauses + + patterns = [ + "MERGE\\s+INTO", + "USING", + "JOIN", + "FROM", + "UPDATE", + "DROP\\s+TABLE", + "TRUNCATE\\s+TABLE", + "CREATE\\s+TABLE", + "INSERT\\s+INTO", + "DELETE\\s+FROM", + ] + + # Combine all patterns into a single regex + combined_pattern = "|".join([pattern + suffix for pattern in patterns]) + + # Find all matches in the query + matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE) + + # Flatten the list of tuples and remove duplication + matches = list(set(sum(matches, ()))) + + # Clean and return unique table names + return [match.strip('`"[]').upper() for match in matches if match] + + @staticmethod + def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: + """ + Execute an SQL query and fetch the results. + + :param conn: An open connection to Snowflake. + :param query: The query to execute. + """ + cur = conn.cursor() + try: + cur.execute(query) + rows = cur.fetchmany(size=MAX_SYS_ROWS) # set a limit to avoid fetching too many rows + + df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) # Convert data to a dataframe + return df + except Exception as e: + if isinstance(e, ProgrammingError): + logger.warning( + "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", + error_msg=e.msg, + sfqid=e.sfqid, + ) + else: + logger.warning("An unexpected error occurred: {error_msg}", error_msg=e) + + return pd.DataFrame() + + @staticmethod + def _has_select_privilege(privileges: list, table_name: str) -> bool: + """ + Check user's privilege for a specific table. + + :param privileges: List of privileges. + :param table_name: Name of the table. + """ + + for privilege in reversed(privileges): + if table_name.lower() == privilege[3].lower() and re.match( + pattern="truncate|update|insert|delete|operate|references", + string=privilege[1], + flags=re.IGNORECASE, + ): + return False + + return True + + def _check_privilege( + self, + conn: SnowflakeConnection, + query: str, + user: str, + ) -> bool: + """ + Check whether a user has a `select`-only access to the table. + + :param conn: An open connection to Snowflake. + :param query: The query from where to extract table names to check read-only access. + """ + cur = conn.cursor() + + cur.execute(f"SHOW GRANTS TO USER {user};") + + # Get user's latest role + roles = cur.fetchall() + if not roles: + logger.error("User does not exist") + return False + + # Last row second column from GRANT table + role = roles[-1][1] + + # Get role privilege + cur.execute(f"SHOW GRANTS TO ROLE {role};") + + # Keep table level privileges + table_privileges = [row for row in cur.fetchall() if row[2] == "TABLE"] + + # Get table names to check for privilege + table_names = self._extract_table_names(query=query) + + for table_name in table_names: + if not self._has_select_privilege( + privileges=table_privileges, + table_name=table_name, + ): + return False + return True + + def _fetch_data( + self, + query: str, + ) -> pd.DataFrame: + """ + Fetch data from a database using a SQL query. + + :param query: SQL query to use to fetch the data from the database. Query must be a valid SQL query. + """ + + df = pd.DataFrame() + if not query: + return df + try: + # Create a new connection with every run + conn = self._snowflake_connector( + connect_params={ + "user": self.user, + "account": self.account, + "password": self.api_key.resolve_value(), + "database": self.database, + "schema": self.db_schema, + "warehouse": self.warehouse, + "login_timeout": self.login_timeout, + } + ) + if conn is None: + return df + except (ForbiddenError, ProgrammingError) as e: + logger.error( + "Error connecting to Snowflake ({errno}): {error_msg}", + errno=e.errno, + error_msg=e.msg, + ) + return df + + try: + # Check if user has `select` privilege on the table + if self._check_privilege( + conn=conn, + query=query, + user=self.user, + ): + df = self._execute_sql_query(conn=conn, query=query) + else: + logger.error("User does not have `Select` privilege on the table.") + + except Exception as e: + logger.error("An unexpected error has occurred: {error}", error=e) + + # Close connection after every execution + conn.close() + return df + + @component.output_types(dataframe=pd.DataFrame, table=str) + def run(self, query: str) -> Dict[str, Any]: + """ + Execute a SQL query against a Snowflake database. + + :param query: A SQL query to execute. + """ + if not query: + logger.error("Provide a valid SQL query.") + return { + "dataframe": pd.DataFrame, + "table": "", + } + else: + df = self._fetch_data(query) + table_markdown = df.to_markdown(index=False) if not df.empty else "" + + return {"dataframe": df, "table": table_markdown} diff --git a/integrations/snowflake/tests/__init__.py b/integrations/snowflake/tests/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/snowflake/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py new file mode 100644 index 000000000..547f7e1b1 --- /dev/null +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from typing import Generator +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +from dateutil.tz import tzlocal +from haystack import Pipeline +from haystack.components.converters import OutputAdapter +from haystack.components.generators import OpenAIGenerator +from haystack.components.builders import PromptBuilder +from haystack.utils import Secret +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from pytest import LogCaptureFixture +from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError + +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever + + +class TestSnowflakeTableRetriever: + @pytest.fixture + def snowflake_table_retriever(self) -> SnowflakeTableRetriever: + return SnowflakeTableRetriever( + user="test_user", + account="test_account", + api_key=Secret.from_token("test-api-key"), + database="test_database", + db_schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_snowflake_connector( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + + conn = snowflake_table_retriever._snowflake_connector( + connect_params={ + "user": "test_user", + "account": "test_account", + "api_key": Secret.from_token("test-api-key"), + "database": "test_database", + "schema": "test_schema", + "warehouse": "test_warehouse", + "login_timeout": 30, + } + ) + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + api_key=Secret.from_token("test-api-key"), + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + assert conn == mock_conn + + def test_query_is_empty( + self, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + query = "" + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + assert "Provide a valid SQL query" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_exception( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect = mock_connect.return_value + mock_connect._fetch_data.side_effect = Exception("Unknown error") + + query = 4 + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + + assert "An unexpected error has occurred" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_forbidden_error_during_connection( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect.side_effect = ForbiddenError(msg="Forbidden error", errno=403) + + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") + + assert result.empty + assert "000403: Forbidden error" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_programing_error_during_connection( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect.side_effect = ProgrammingError(msg="Programming error", errno=403) + + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") + + assert result.empty + assert "000403: Programming error" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query_programming_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_conn = MagicMock() + mock_cursor = mock_conn.cursor.return_value + + mock_cursor.execute.side_effect = ProgrammingError(msg="Simulated programming error", sfqid="ABC-123") + + result = snowflake_table_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") + + assert result.empty + + assert ( + "Simulated programming error Use the following ID to check the status of " + "the query in Snowflake UI (ID: ABC-123)" in caplog.text + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_connection_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234) + + query = "SELECT * FROM test_table" + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + + def test_extract_single_table_name(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + queries = [ + "SELECT * FROM table_a", + "SELECT name, value FROM (SELECT name, value FROM table_a) AS subquery", + "SELECT name, value FROM (SELECT name, value FROM table_a ) AS subquery", + "UPDATE table_a SET value = 'new_value' WHERE id = 1", + "INSERT INTO table_a (id, name, value) VALUES (1, 'name1', 'value1')", + "DELETE FROM table_a WHERE id = 1", + "TRUNCATE TABLE table_a", + "DROP TABLE table_a", + ] + for query in queries: + result = snowflake_table_retriever._extract_table_names(query) + assert result == ["TABLE_A"] + + def test_extract_database_and_schema_from_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + # when database and schema are next to table name + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ + "DB.SCHEMA.TABLE_A" + ] + # No database + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == [ + "SCHEMA.TABLE_A" + ] + + def test_extract_multiple_table_names(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + queries = [ + "MERGE INTO table_a USING table_b ON table_a.id = table_b.id WHEN MATCHED", + "SELECT a.name, b.value FROM table_a AS a FULL OUTER JOIN table_b AS b ON a.id = b.id", + "SELECT a.name, b.value FROM table_a AS a RIGHT JOIN table_b AS b ON a.id = b.id", + ] + for query in queries: + result = snowflake_table_retriever._extract_table_names(query) + # Due to using set when deduplicating + assert result == ["TABLE_A", "TABLE_B"] or ["TABLE_B", "TABLE_A"] + + def test_extract_multiple_db_schema_from_table_names( + self, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + assert ( + snowflake_table_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a + FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"] + or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B"] + ) + # No database + assert ( + snowflake_table_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a + FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] + or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_cursor.description = [mock_col1, mock_col2] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM test_table" + expected = pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) + + mock_cursor.execute.assert_called_once_with(query) + + assert result.equals(expected) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_is_select_only( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "LOCATIONS", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], # Table privileges + ] + + query = "select * from locations" + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + assert result + + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "INSERT", + "TABLE", + "LOCATIONS", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + + assert not result + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_column_after_from( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_col1.name = "id" + mock_col2.name = "year" + mock_cursor.fetchmany.return_value = [(1233, 1998)] + mock_cursor.description = [mock_col1, mock_col2] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT id, extract(year from date_col) as year FROM test_table" + expected = pd.DataFrame(data={"id": [1233], "year": [1998]}) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) + mock_cursor.execute.assert_called_once_with(query) + + assert result.equals(expected) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "locations", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.description = [mock_col1, mock_col2] + + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM locations" + + expected = { + "dataframe": pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}), + "table": "| City | State |\n|:--------|:---------|\n| Chicago | Illinois |", + } + + result = snowflake_table_retriever.run(query=query) + + assert result["dataframe"].equals(expected["dataframe"]) + assert result["table"] == expected["table"] + + @pytest.fixture + def mock_chat_completion(self) -> Generator: + """ + Mock the OpenAI API completion response and reuse it for tests + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + Choice( + finish_reason="stop", + logprobs=None, + index=0, + message=ChatCompletionMessage(content="select locations from table_a", role="assistant"), + ) + ], + created=int(datetime.now(tz=tzlocal()).timestamp()), + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_pipeline( + self, + mock_connect: MagicMock, + mock_chat_completion: MagicMock, + snowflake_table_retriever: SnowflakeTableRetriever, + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "test_database.test_schema.table_a", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "locations" + + mock_cursor.description = [mock_col1] + + mock_cursor.fetchmany.return_value = [("Chicago",), ("Miami",), ("Berlin",)] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + expected = { + "dataframe": pd.DataFrame(data={"locations": ["Chicago", "Miami", "Berlin"]}), + "table": "| locations |\n|:------------|\n| Chicago |\n| Miami |\n| Berlin |", + } + + llm = OpenAIGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key")) + adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) + pipeline = Pipeline() + + pipeline.add_component("llm", llm) + pipeline.add_component("adapter", adapter) + pipeline.add_component("snowflake", snowflake_table_retriever) + + pipeline.connect(sender="llm.replies", receiver="adapter.replies") + pipeline.connect(sender="adapter.output", receiver="snowflake.query") + + result = pipeline.run(data={"llm": {"prompt": "Generate a SQL query that extract all locations from table_a"}}) + + assert result["snowflake"]["dataframe"].equals(expected["dataframe"]) + assert result["snowflake"]["table"] == expected["table"] + + def test_from_dict(self, monkeypatch: MagicMock) -> None: + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + data = { + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever" + ".SnowflakeTableRetriever", + "init_parameters": { + "api_key": { + "env_vars": ["SNOWFLAKE_API_KEY"], + "strict": True, + "type": "env_var", + }, + "user": "test_user", + "account": "new_account", + "database": "test_database", + "db_schema": "test_schema", + "warehouse": "test_warehouse", + "login_timeout": 3, + }, + } + component = SnowflakeTableRetriever.from_dict(data) + + assert component.user == "test_user" + assert component.account == "new_account" + assert component.api_key == Secret.from_env_var("SNOWFLAKE_API_KEY") + assert component.database == "test_database" + assert component.db_schema == "test_schema" + assert component.warehouse == "test_warehouse" + assert component.login_timeout == 3 + + def test_to_dict_default(self, monkeypatch: MagicMock) -> None: + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + component = SnowflakeTableRetriever( + user="test_user", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + account="test_account", + database="test_database", + db_schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + data = component.to_dict() + + assert data == { + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", + "init_parameters": { + "api_key": { + "env_vars": ["SNOWFLAKE_API_KEY"], + "strict": True, + "type": "env_var", + }, + "user": "test_user", + "account": "test_account", + "database": "test_database", + "db_schema": "test_schema", + "warehouse": "test_warehouse", + "login_timeout": 30, + }, + } + + def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + component = SnowflakeTableRetriever( + user="John", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + account="TGMD-EEREW", + database="CITY", + db_schema="SMALL_TOWNS", + warehouse="COMPUTE_WH", + login_timeout=30, + ) + + data = component.to_dict() + + assert data == { + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", + "init_parameters": { + "api_key": { + "env_vars": ["SNOWFLAKE_API_KEY"], + "strict": True, + "type": "env_var", + }, + "user": "John", + "account": "TGMD-EEREW", + "database": "CITY", + "db_schema": "SMALL_TOWNS", + "warehouse": "COMPUTE_WH", + "login_timeout": 30, + }, + } + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_has_select_privilege( + self, mock_logger: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + # Define test cases + test_cases = [ + # Test case 1: Fully qualified table name in query + { + "privileges": [[None, "SELECT", None, "table"]], + "table_name": "table", + "expected_result": True, + }, + # Test case 2: Schema and table names in query, database name as argument + { + "privileges": [[None, "SELECT", None, "table"]], + "table_name": "table", + "expected_result": True, + }, + # Test case 3: Only table name in query, database and schema names as arguments + { + "privileges": [[None, "SELECT", None, "table"]], + "table_name": "table", + "expected_result": True, + }, + # Test case 5: Privilege does not match + { + "privileges": [[None, "INSERT", None, "table"]], + "table_name": "table", + "expected_result": False, + }, + # Test case 6: Case-insensitive match + { + "privileges": [[None, "select", None, "table"]], + "table_name": "TABLE", + "expected_result": True, + }, + ] + + for case in test_cases: + result = snowflake_table_retriever._has_select_privilege( + privileges=case["privileges"], # type: ignore + table_name=case["table_name"], # type: ignore + ) + assert result == case["expected_result"] # type: ignore + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_user_does_not_exist( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + + mock_cursor = mock_conn.cursor.return_value + mock_cursor.fetchall.return_value = [] + + result = snowflake_table_retriever._fetch_data(query="""SELECT * FROM test_table""") + + assert result.empty + assert "User does not exist" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + result = snowflake_table_retriever._fetch_data(query="") + + assert result.empty + + def test_serialization_deserialization_pipeline(self) -> None: + + pipeline = Pipeline() + pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) + pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}")) + pipeline.connect("snow.table", "prompt_builder.table") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline