From 25d9dd47255ed9cc56adfa454c9224acdcaad1e4 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 19:01:51 -0500 Subject: [PATCH 01/11] initial commit --- integrations/snowflake/README.md | 0 integrations/snowflake/pyproject.toml | 0 .../retrievers/snowflake/__init__.py | 7 + .../snowflake/snowflake_retriever.py | 329 +++++++++++ integrations/snowflake/tests/__init__.py | 0 .../tests/test_snowflake_retriever.py | 552 ++++++++++++++++++ 6 files changed, 888 insertions(+) create mode 100644 integrations/snowflake/README.md create mode 100644 integrations/snowflake/pyproject.toml create mode 100644 integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py create mode 100644 integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py create mode 100644 integrations/snowflake/tests/__init__.py create mode 100644 integrations/snowflake/tests/test_snowflake_retriever.py diff --git a/integrations/snowflake/README.md b/integrations/snowflake/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py new file mode 100644 index 000000000..dd409ba06 --- /dev/null +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .snowflake_retriever import SnowflakeRetriever + +__all__ = ["SnowflakeRetriever"] diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py 
import re
from typing import Any, Dict, Final, Optional, Union

import pandas as pd
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace

with LazyImport("Run 'pip install snowflake-connector-python>=3.10.1'") as snow_import:
    import snowflake.connector
    from snowflake.connector.connection import SnowflakeConnection
    from snowflake.connector.errors import (
        DatabaseError,
        ForbiddenError,
        ProgrammingError,
    )

logger = logging.getLogger(__name__)

MAX_SYS_ROWS: Final = 1000000  # Max rows to fetch from a table


@component
class SnowflakeRetriever:
    """
    Connects to a Snowflake database to execute a SQL query.
    For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector).

    ### Usage example:

    ```python
    executor = SnowflakeRetriever(
        user="<ACCOUNT-USER>",
        account="<ACCOUNT-IDENTIFIER>",
        api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
        database="<DATABASE-NAME>",
        db_schema="<SCHEMA-NAME>",
        warehouse="<WAREHOUSE-NAME>",
    )

    # When database and schema are provided during component initialization.
    query = "SELECT * FROM table_name"

    # or

    # When database and schema are NOT provided during component initialization.
    query = "SELECT * FROM database_name.schema_name.table_name"

    results = executor.run(query=query)

    print(results["dataframe"].head(2))
    #   Column 1  Column 2
    # 0   Value1    Value2
    # 1   Value1    Value2

    print(results["table"])
    # | Column 1  | Column 2  |
    # |:----------|:----------|
    # | Value1    | Value2    |
    # | Value1    | Value2    |
    ```
    """

    def __init__(
        self,
        user: str,
        account: str,
        api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"),  # noqa: B008
        database: Optional[str] = None,
        db_schema: Optional[str] = None,
        warehouse: Optional[str] = None,
        login_timeout: Optional[int] = None,
    ) -> None:
        """
        :param user: User's login.
        :param account: Snowflake account identifier.
        :param api_key: Snowflake account password.
        :param database: Name of the database to use.
        :param db_schema: Name of the schema to use.
        :param warehouse: Name of the warehouse to use.
        :param login_timeout: Timeout in seconds for login. By default, 60 seconds.
        """
        self.user = user
        self.account = account
        self.api_key = api_key
        self.database = database
        self.db_schema = db_schema
        self.warehouse = warehouse
        self.login_timeout = login_timeout or 60

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(  # type: ignore
            self,
            user=self.user,
            account=self.account,
            api_key=self.api_key.to_dict(),
            database=self.database,
            db_schema=self.db_schema,
            warehouse=self.warehouse,
            login_timeout=self.login_timeout,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeRetriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # `init_params` aliases the nested dict, so the in-place secret
        # deserialization is visible to `default_from_dict` below.
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, ["api_key"])
        return default_from_dict(cls, data)  # type: ignore

    @staticmethod
    def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConnection, None]:
        """
        Connect to a Snowflake database.

        :param connect_params: Snowflake connection parameters.
        :returns: An open connection, or `None` when the connection attempt fails.
        """
        try:
            return snowflake.connector.connect(**connect_params)
        except DatabaseError as e:
            logger.error("{error_msg} ", errno=e.errno, error_msg=e.msg)
            return None

    @staticmethod
    def _extract_table_names(query: str) -> list:
        """
        Extract table names from a SQL query using regex.
        The extracted table names will be checked for privilege.

        :param query: SQL query to extract table names from.
        :returns: Upper-cased, de-duplicated table names (order is not guaranteed
            because a `set` is used for de-duplication).
        """
        # Regular expressions to match table names in various clauses
        suffix = "\\s+([a-zA-Z0-9_.]+)"

        patterns = [
            "MERGE\\s+INTO",
            "USING",
            "JOIN",
            "FROM",
            "UPDATE",
            "DROP\\s+TABLE",
            "TRUNCATE\\s+TABLE",
            "CREATE\\s+TABLE",
            "INSERT\\s+INTO",
            "DELETE\\s+FROM",
        ]

        # Combine all patterns into a single regex
        combined_pattern = "|".join([pattern + suffix for pattern in patterns])

        # Find all matches in the query
        matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE)

        # Flatten list of tuples and remove duplication
        matches = list(set(sum(matches, ())))

        # Clean and return unique table names
        return [match.strip('`"[]').upper() for match in matches if match]

    @staticmethod
    def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame:
        """
        Execute a SQL query and fetch the results.

        :param conn: An open connection to Snowflake.
        :param query: The query to execute.
        :returns: The query result as a DataFrame; empty on a programming error.
        """
        cur = conn.cursor()
        try:
            cur.execute(query)
            # set a limit to MAX_SYS_ROWS rows to avoid fetching too many rows
            rows = cur.fetchmany(size=MAX_SYS_ROWS)
            # Convert data to a dataframe
            return pd.DataFrame(rows, columns=[desc.name for desc in cur.description])
        except ProgrammingError as e:
            logger.warning(
                "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})",
                error_msg=e.msg,
                sfqid=e.sfqid,
            )
            return pd.DataFrame()
        finally:
            # FIX: the cursor was previously leaked on every call.
            cur.close()

    @staticmethod
    def _has_select_privilege(privileges: list, table_name: str) -> bool:
        """
        Check user's privilege for a specific table.

        :param privileges: List of privileges (rows from `SHOW GRANTS TO ROLE`).
        :param table_name: Name of the table.
        :returns: False when a write-level grant is found for the table,
            True otherwise (i.e. treated as select-only access).
        """
        for privilege in reversed(privileges):
            # privilege[1] is the granted privilege name, privilege[3] the object name.
            if table_name.lower() == privilege[3].lower() and re.match(
                pattern="truncate|update|insert|delete|operate|references",
                string=privilege[1],
                flags=re.IGNORECASE,
            ):
                logger.error("User does not have `Select` privilege on the table.")
                return False

        return True

    def _check_privilege(
        self,
        conn: SnowflakeConnection,
        query: str,
        user: str,
    ) -> bool:
        """
        Check whether a user has a `select`-only access to the table.

        :param conn: An open connection to Snowflake.
        :param query: The query from where to extract table names to check read-only access.
        :param user: User's login, used to look up grants.
        """
        cur = conn.cursor()
        try:
            cur.execute(f"SHOW GRANTS TO USER {user};")

            # Get user's latest role
            roles = cur.fetchall()
            if not roles:
                logger.error("User does not exist")
                return False

            # Last row second column from GRANT table
            role = roles[-1][1]

            # Get role privilege
            cur.execute(f"SHOW GRANTS TO ROLE {role};")

            # Keep table level privileges
            table_privileges = [row for row in cur.fetchall() if row[2] == "TABLE"]

            # Get table names to check for privilege
            table_names = self._extract_table_names(query=query)

            for table_name in table_names:
                if not self._has_select_privilege(
                    privileges=table_privileges,
                    table_name=table_name,
                ):
                    return False
            return True
        finally:
            # FIX: the cursor was previously leaked on every call.
            cur.close()

    def _fetch_data(
        self,
        query: str,
    ) -> pd.DataFrame:
        """
        Fetch data from a database using a SQL query.

        :param query: SQL query to use to fetch the data from the database. Query must be a valid SQL query.
        :returns: Query result; an empty DataFrame on any connection/privilege/query failure.
        """
        df = pd.DataFrame()
        if not query:
            return df
        try:
            # Create a new connection with every run
            conn = self._snowflake_connector(
                connect_params={
                    "user": self.user,
                    "account": self.account,
                    "password": self.api_key.resolve_value(),
                    "database": self.database,
                    "schema": self.db_schema,
                    "warehouse": self.warehouse,
                    "login_timeout": self.login_timeout,
                }
            )
            if conn is None:
                return df
        except (ForbiddenError, ProgrammingError) as e:
            logger.error(
                "Error connecting to Snowflake ({errno}): {error_msg}",
                errno=e.errno,
                error_msg=e.msg,
            )
            return df

        try:
            # Check if user has `select` privilege on the table
            if self._check_privilege(
                conn=conn,
                query=query,
                user=self.user,
            ):
                df = self._execute_sql_query(conn=conn, query=query)
        except Exception as e:
            logger.error("An unexpected error has occurred: {error}", error=e)

        # Close connection after every execution
        conn.close()
        return df

    @component.output_types(dataframe=pd.DataFrame, table=str)
    def run(self, query: str) -> Dict[str, Any]:
        """
        Execute a SQL query against a Snowflake database.

        :param query: A SQL query to execute.
        :returns: Dict with `dataframe` (the result set) and `table`
            (its Markdown rendering, or "" when the result is empty).
        """
        if not query:
            logger.error("Provide a valid SQL query.")
            # BUG FIX: previously returned the `pd.DataFrame` *class* instead of
            # an instance, so callers inspecting `.empty` got a property object.
            return {
                "dataframe": pd.DataFrame(),
                "table": "",
            }

        df = self._fetch_data(query)
        table_markdown = df.to_markdown(index=False) if not df.empty else ""
        return {"dataframe": df, "table": table_markdown}
# SPDX-FileCopyrightText: 2024-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from typing import Generator
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest
from dateutil.tz import tzlocal
from haystack import Pipeline
from haystack.components.converters import OutputAdapter
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from pytest import LogCaptureFixture
from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError

from haystack_integrations.components.retrievers.snowflake.snowflake_retriever import (
    MAX_SYS_ROWS,
    SnowflakeRetriever,
)

# FIX: patches previously targeted `deepset_cloud_custom_nodes.augmenters...`,
# a module that does not exist in this repo. Patch where the name is looked up.
_CONNECT = "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect"
# `default_to_dict`/`default_from_dict` use the real import path as the `type` string.
_TYPE = "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever"


class TestSnowflakeRetriever:
    @pytest.fixture
    def snowflake_retriever(self) -> SnowflakeRetriever:
        return SnowflakeRetriever(
            user="test_user",
            account="test_account",
            api_key=Secret.from_token("test-api-key"),
            database="test_database",
            db_schema="test_schema",
            warehouse="test_warehouse",
            login_timeout=30,
        )

    @patch(_CONNECT)
    def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        mock_conn = MagicMock()
        mock_connect.return_value = mock_conn

        conn = snowflake_retriever._snowflake_connector(
            connect_params={
                "user": "test_user",
                "account": "test_account",
                "api_key": Secret.from_token("test-api-key"),
                "database": "test_database",
                "schema": "test_schema",
                "warehouse": "test_warehouse",
                "login_timeout": 30,
            }
        )
        mock_connect.assert_called_once_with(
            user="test_user",
            account="test_account",
            api_key=Secret.from_token("test-api-key"),
            database="test_database",
            schema="test_schema",
            warehouse="test_warehouse",
            login_timeout=30,
        )

        assert conn == mock_conn

    def test_query_is_empty(self, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture) -> None:
        result = snowflake_retriever.run(query="")

        assert result["table"] == ""
        assert result["dataframe"].empty
        assert "Provide a valid SQL query" in caplog.text

    @patch(_CONNECT)
    def test_exception(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_connect.return_value._fetch_data.side_effect = Exception("Unknown error")

        # A non-string query makes `_extract_table_names` raise and exercises
        # the catch-all error path.
        result = snowflake_retriever.run(query=4)

        assert result["table"] == ""
        assert result["dataframe"].empty
        assert "An unexpected error has occurred" in caplog.text

    @patch(_CONNECT)
    def test_forbidden_error_during_connection(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_connect.side_effect = ForbiddenError(msg="Forbidden error", errno=403)

        result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table")

        assert result.empty
        assert "000403: Forbidden error" in caplog.text

    @patch(_CONNECT)
    def test_programing_error_during_connection(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_connect.side_effect = ProgrammingError(msg="Programming error", errno=403)

        result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table")

        assert result.empty
        assert "000403: Programming error" in caplog.text

    @patch(_CONNECT)
    def test_execute_sql_query_programming_error(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_conn = MagicMock()
        mock_cursor = mock_conn.cursor.return_value
        mock_cursor.execute.side_effect = ProgrammingError(msg="Simulated programming error", sfqid="ABC-123")

        result = snowflake_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table")

        assert result.empty
        assert (
            "Simulated programming error Use the following ID to check the status of "
            "the query in Snowflake UI (ID: ABC-123)" in caplog.text
        )

    @patch(_CONNECT)
    def test_run_connection_error(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234)

        result = snowflake_retriever.run(query="SELECT * FROM test_table")

        assert result["table"] == ""
        assert result["dataframe"].empty

    def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever) -> None:
        queries = [
            "SELECT * FROM table_a",
            "SELECT name, value FROM (SELECT name, value FROM table_a) AS subquery",
            "SELECT name, value FROM (SELECT name, value FROM table_a ) AS subquery",
            "UPDATE table_a SET value = 'new_value' WHERE id = 1",
            "INSERT INTO table_a (id, name, value) VALUES (1, 'name1', 'value1')",
            "DELETE FROM table_a WHERE id = 1",
            "TRUNCATE TABLE table_a",
            "DROP TABLE table_a",
        ]
        for query in queries:
            assert snowflake_retriever._extract_table_names(query) == ["TABLE_A"]

    def test_extract_database_and_schema_from_query(self, snowflake_retriever: SnowflakeRetriever) -> None:
        # when database and schema are next to table name
        assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [
            "DB.SCHEMA.TABLE_A"
        ]
        # No database
        assert snowflake_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == ["SCHEMA.TABLE_A"]

    def test_extract_multiple_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None:
        queries = [
            "MERGE INTO table_a USING table_b ON table_a.id = table_b.id WHEN MATCHED",
            "SELECT a.name, b.value FROM table_a AS a FULL OUTER JOIN table_b AS b ON a.id = b.id",
            "SELECT a.name, b.value FROM table_a AS a RIGHT JOIN table_b AS b ON a.id = b.id",
        ]
        for query in queries:
            # BUG FIX: was `== [...] or [...]`, which is always truthy.
            # Order is unspecified (set-based de-duplication), so sort.
            assert sorted(snowflake_retriever._extract_table_names(query)) == ["TABLE_A", "TABLE_B"]

    def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None:
        # BUG FIX: former assertion was tautological (`== [...] or [...]`) and
        # hid a wrong expected name (DATABASE.SCHEMA.TABLE_A vs ..._B).
        assert sorted(
            snowflake_retriever._extract_table_names(
                query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a """
                + """FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id"""
            )
        ) == ["DATABASE.SCHEMA.TABLE_B", "DB.SCHEMA.TABLE_A"]
        # No database
        assert sorted(
            snowflake_retriever._extract_table_names(
                query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a """
                + """FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id"""
            )
        ) == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"]

    @patch(_CONNECT)
    def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_col1 = MagicMock()
        mock_col2 = MagicMock()
        mock_col1.name = "City"
        mock_col2.name = "State"
        mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")]
        mock_cursor.description = [mock_col1, mock_col2]
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn

        query = "SELECT * FROM test_table"
        expected = pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]})
        result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query)

        mock_cursor.execute.assert_called_once_with(query)
        mock_cursor.fetchmany.assert_called_once_with(size=MAX_SYS_ROWS)
        assert result.equals(expected)

    @patch(_CONNECT)
    def test_is_select_only(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn
        mock_cursor.fetchall.side_effect = [
            [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")],  # User roles
            [
                (
                    "DATETIME",
                    "SELECT",
                    "TABLE",
                    "LOCATIONS",
                    "ROLE",
                    "ROLE_NAME",
                    "GRANT_OPTION",
                    "GRANTED_BY",
                )
            ],  # Table privileges
        ]

        query = "select * from locations"
        # BUG FIX: `_check_privilege` takes no `database_name`/`schema_name`
        # kwargs; passing them raised TypeError.
        assert snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query)

        mock_cursor.fetchall.side_effect = [
            [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")],  # User roles
            [
                (
                    "DATETIME",
                    "INSERT",
                    "TABLE",
                    "LOCATIONS",
                    "ROLE",
                    "ROLE_NAME",
                    "GRANT_OPTION",
                    "GRANTED_BY",
                )
            ],
        ]

        assert not snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query)
        assert "User does not have `Select` privilege on the table" in caplog.text

    @patch(_CONNECT)
    def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_col1 = MagicMock()
        mock_col2 = MagicMock()
        mock_col1.name = "id"
        mock_col2.name = "year"
        mock_cursor.fetchmany.return_value = [(1233, 1998)]
        mock_cursor.description = [mock_col1, mock_col2]
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn

        query = "SELECT id, extract(year from date_col) as year FROM test_table"
        expected = pd.DataFrame(data={"id": [1233], "year": [1998]})
        result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query)

        mock_cursor.execute.assert_called_once_with(query)
        mock_cursor.fetchmany.assert_called_once_with(size=MAX_SYS_ROWS)
        assert result.equals(expected)

    @patch(_CONNECT)
    def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_col1 = MagicMock()
        mock_col2 = MagicMock()
        mock_cursor.fetchall.side_effect = [
            [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")],  # User roles
            [
                (
                    "DATETIME",
                    "SELECT",
                    "TABLE",
                    "locations",
                    "ROLE",
                    "ROLE_NAME",
                    "GRANT_OPTION",
                    "GRANTED_BY",
                )
            ],
        ]
        mock_col1.name = "City"
        mock_col2.name = "State"
        mock_cursor.description = [mock_col1, mock_col2]
        mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")]
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn

        expected = {
            "dataframe": pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}),
            "table": "| City    | State    |\n|:--------|:---------|\n| Chicago | Illinois |",
        }

        result = snowflake_retriever.run(query="SELECT * FROM locations")

        assert result["dataframe"].equals(expected["dataframe"])
        assert result["table"] == expected["table"]

    @pytest.fixture
    def mock_chat_completion(self) -> Generator:
        """
        Mock the OpenAI API completion response and reuse it for tests
        """
        with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
            completion = ChatCompletion(
                id="foo",
                model="gpt-4o",
                object="chat.completion",
                choices=[
                    Choice(
                        finish_reason="stop",
                        logprobs=None,
                        index=0,
                        message=ChatCompletionMessage(content="select locations from table_a", role="assistant"),
                    )
                ],
                created=int(datetime.now(tz=tzlocal()).timestamp()),
                usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
            )

            mock_chat_completion_create.return_value = completion
            yield mock_chat_completion_create

    @patch(_CONNECT)
    def test_run_pipeline(
        self, mock_connect: MagicMock, mock_chat_completion: MagicMock, snowflake_retriever: SnowflakeRetriever
    ) -> None:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_col1 = MagicMock()
        mock_cursor.fetchall.side_effect = [
            [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")],  # User roles
            [
                (
                    "DATETIME",
                    "SELECT",
                    "TABLE",
                    "test_database.test_schema.table_a",
                    "ROLE",
                    "ROLE_NAME",
                    "GRANT_OPTION",
                    "GRANTED_BY",
                )
            ],
        ]
        mock_col1.name = "locations"
        mock_cursor.description = [mock_col1]
        mock_cursor.fetchmany.return_value = [("Chicago",), ("Miami",), ("Berlin",)]
        mock_conn.cursor.return_value = mock_cursor
        mock_connect.return_value = mock_conn

        expected = {
            "dataframe": pd.DataFrame(data={"locations": ["Chicago", "Miami", "Berlin"]}),
            "table": "| locations   |\n|:------------|\n| Chicago     |\n| Miami       |\n| Berlin      |",
        }

        llm = OpenAIGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
        adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)
        pipeline = Pipeline()

        pipeline.add_component("llm", llm)
        pipeline.add_component("adapter", adapter)
        pipeline.add_component("snowflake", snowflake_retriever)

        pipeline.connect(sender="llm.replies", receiver="adapter.replies")
        pipeline.connect(sender="adapter.output", receiver="snowflake.query")

        result = pipeline.run(data={"llm": {"prompt": "Generate a SQL query that extract all locations from table_a"}})

        assert result["snowflake"]["dataframe"].equals(expected["dataframe"])
        assert result["snowflake"]["table"] == expected["table"]

    def test_from_dict(self, monkeypatch: MagicMock) -> None:
        monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key")
        # FIX: `default_from_dict` validates the `type` string, so it must be
        # the real import path, not the old deepset-cloud one.
        data = {
            "type": _TYPE,
            "init_parameters": {
                "api_key": {
                    "env_vars": ["SNOWFLAKE_API_KEY"],
                    "strict": True,
                    "type": "env_var",
                },
                "user": "test_user",
                "account": "new_account",
                "database": "test_database",
                "db_schema": "test_schema",
                "warehouse": "test_warehouse",
                "login_timeout": 3,
            },
        }
        component = SnowflakeRetriever.from_dict(data)

        assert component.user == "test_user"
        assert component.account == "new_account"
        assert component.api_key == Secret.from_env_var("SNOWFLAKE_API_KEY")
        assert component.database == "test_database"
        assert component.db_schema == "test_schema"
        assert component.warehouse == "test_warehouse"
        assert component.login_timeout == 3

    def test_to_dict_default(self, monkeypatch: MagicMock) -> None:
        monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key")
        component = SnowflakeRetriever(
            user="test_user",
            api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
            account="test_account",
            database="test_database",
            db_schema="test_schema",
            warehouse="test_warehouse",
            login_timeout=30,
        )

        assert component.to_dict() == {
            "type": _TYPE,
            "init_parameters": {
                "api_key": {
                    "env_vars": ["SNOWFLAKE_API_KEY"],
                    "strict": True,
                    "type": "env_var",
                },
                "user": "test_user",
                "account": "test_account",
                "database": "test_database",
                "db_schema": "test_schema",
                "warehouse": "test_warehouse",
                "login_timeout": 30,
            },
        }

    def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
        monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key")
        component = SnowflakeRetriever(
            user="John",
            api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
            account="TGMD-EEREW",
            database="CITY",
            db_schema="SMALL_TOWNS",
            warehouse="COMPUTE_WH",
            login_timeout=30,
        )

        assert component.to_dict() == {
            "type": _TYPE,
            "init_parameters": {
                "api_key": {
                    "env_vars": ["SNOWFLAKE_API_KEY"],
                    "strict": True,
                    "type": "env_var",
                },
                "user": "John",
                "account": "TGMD-EEREW",
                "database": "CITY",
                "db_schema": "SMALL_TOWNS",
                "warehouse": "COMPUTE_WH",
                "login_timeout": 30,
            },
        }

    def test_has_select_privilege(self, snowflake_retriever: SnowflakeRetriever) -> None:
        # FIX: dropped a pointless `@patch` whose mock argument was never used.
        test_cases = [
            # SELECT grant matches -> allowed
            {
                "privileges": [[None, "SELECT", None, "table"]],
                "table_name": "table",
                "expected_result": True,
            },
            # Write-level grant -> not select-only
            {
                "privileges": [[None, "INSERT", None, "table"]],
                "table_name": "table",
                "expected_result": False,
            },
            # Case-insensitive match
            {
                "privileges": [[None, "select", None, "table"]],
                "table_name": "TABLE",
                "expected_result": True,
            },
        ]

        for case in test_cases:
            result = snowflake_retriever._has_select_privilege(
                privileges=case["privileges"],  # type: ignore
                table_name=case["table_name"],  # type: ignore
            )
            assert result == case["expected_result"]  # type: ignore

    @patch(_CONNECT)
    def test_user_does_not_exist(
        self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture
    ) -> None:
        mock_conn = MagicMock()
        mock_connect.return_value = mock_conn
        mock_conn.cursor.return_value.fetchall.return_value = []

        result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table")

        assert result.empty
        assert "User does not exist" in caplog.text

    @patch(_CONNECT)
    def test_empty_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None:
        result = snowflake_retriever._fetch_data(query="")

        assert result.empty
.../tests/test_snowflake_retriever.py | 56 +++++++++---------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/integrations/snowflake/tests/__init__.py b/integrations/snowflake/tests/__init__.py index e69de29bb..6b5e14dc1 100644 --- a/integrations/snowflake/tests/__init__.py +++ b/integrations/snowflake/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/snowflake/tests/test_snowflake_retriever.py b/integrations/snowflake/tests/test_snowflake_retriever.py index ba96e7d39..98832af72 100644 --- a/integrations/snowflake/tests/test_snowflake_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_retriever.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from datetime import datetime from typing import Generator from unittest.mock import MagicMock, patch @@ -14,7 +18,7 @@ from pytest import LogCaptureFixture from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError -from haystack_integrations.components.retrievers.snowflake_retriever import MAX_SYS_ROWS, SnowflakeRetriever +from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever class TestSnowflakeRetriever: @@ -30,7 +34,7 @@ def snowflake_retriever(self) -> SnowflakeRetriever: login_timeout=30, ) - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: mock_conn = MagicMock() mock_connect.return_value = mock_conn @@ -66,7 +70,7 @@ def test_query_is_empty(self, snowflake_retriever: SnowflakeRetriever, caplog: L assert result["dataframe"].empty assert "Provide a valid SQL query" in caplog.text - 
@patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_exception( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -82,7 +86,7 @@ def test_exception( assert "An unexpected error has occurred" in caplog.text - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_forbidden_error_during_connection( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -93,7 +97,7 @@ def test_forbidden_error_during_connection( assert result.empty assert "000403: Forbidden error" in caplog.text - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_programing_error_during_connection( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -104,7 +108,7 @@ def test_programing_error_during_connection( assert result.empty assert "000403: Programming error" in caplog.text - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_execute_sql_query_programming_error( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -123,7 +127,7 @@ def test_execute_sql_query_programming_error( "the query in Snowflake UI (ID: ABC-123)" in caplog.text ) - 
@patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_run_connection_error(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234) @@ -151,7 +155,9 @@ def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever def test_extract_database_and_schema_from_query(self, snowflake_retriever: SnowflakeRetriever) -> None: # when database and schema are next to table name - assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == ["DB.SCHEMA.TABLE_A"] + assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ + "DB.SCHEMA.TABLE_A" + ] # No database assert snowflake_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == ["SCHEMA.TABLE_A"] @@ -178,7 +184,7 @@ def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: + """FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" ) == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -196,11 +202,10 @@ def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: S result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) - mock_cursor.fetchmany.assert_called_once_with(size=MAX_SYS_ROWS) assert result.equals(expected) - 
@patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_is_select_only( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -225,9 +230,7 @@ def test_is_select_only( ] query = "select * from locations" - result = snowflake_retriever._check_privilege( - conn=mock_conn, user="test_user", query=query, database_name="test_database", schema_name="test_schema" - ) + result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) assert result mock_cursor.fetchall.side_effect = [ @@ -246,14 +249,12 @@ def test_is_select_only( ], ] - result = snowflake_retriever._check_privilege( - conn=mock_conn, user="test_user", query=query, database_name="test_database", schema_name="test_schema" - ) + result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) print(result) assert not result assert "User does not have `Select` privilege on the table" in caplog.text - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -270,11 +271,10 @@ def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: S expected = pd.DataFrame(data={"id": [1233], "year": [1998]}) result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) - mock_cursor.fetchmany.assert_called_once_with(size=MAX_SYS_ROWS) assert result.equals(expected) - 
@patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -340,7 +340,7 @@ def mock_chat_completion(self) -> Generator: mock_chat_completion_create.return_value = completion yield mock_chat_completion_create - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_run_pipeline( self, mock_connect: MagicMock, mock_chat_completion: MagicMock, snowflake_retriever: SnowflakeRetriever ) -> None: @@ -395,7 +395,7 @@ def test_run_pipeline( def test_from_dict(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") data = { - "type": "deepset_cloud_custom_nodes.augmenters.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -435,7 +435,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "deepset_cloud_custom_nodes.augmenters.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -467,7 +467,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "deepset_cloud_custom_nodes.augmenters.snowflake_retriever.SnowflakeRetriever", + "type": 
"haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -483,7 +483,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: }, } - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_has_select_privilege( self, mock_logger: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -528,7 +528,7 @@ def test_has_select_privilege( ) assert result == case["expected_result"] # type: ignore - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_user_does_not_exist( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: @@ -544,9 +544,9 @@ def test_user_does_not_exist( assert result.empty assert "User does not exist" in caplog.text - @patch("deepset_cloud_custom_nodes.augmenters.snowflake_retriever.snowflake.connector.connect") + @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_empty_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: result = snowflake_retriever._fetch_data(query="") - assert result.empty \ No newline at end of file + assert result.empty From c787c371758019dd02e64cf08822577735407fa6 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 19:02:59 -0500 Subject: [PATCH 03/11] add pyproject.toml --- integrations/snowflake/pyproject.toml | 146 ++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index 
e69de29bb..63434d1bf 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -0,0 +1,146 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "snowflake-haystack" +dynamic = ["version"] +description = 'A Snowflake integration for the Haystack framework.' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", + "snowflake-connector-python>=3.10.1", + "tabulate>=0.9.0"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/snowflake-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +all = ["style", "typing"] + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["snowflake_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*"] +ignore_missing_imports = true \ No newline at end of file From c20f1e20a0930f0739040e6aae4c9377bf4c19ab Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 21:16:03 -0500 Subject: [PATCH 04/11] add pydoc config --- integrations/snowflake/pydoc/config.yml | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 integrations/snowflake/pydoc/config.yml diff --git a/integrations/snowflake/pydoc/config.yml b/integrations/snowflake/pydoc/config.yml new file mode 100644 index 000000000..7237b3816 --- /dev/null +++ b/integrations/snowflake/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: 
haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: + [ + "haystack_integrations.components.retrievers.snowflake.snowflake_retriever" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Snowflake integration for Haystack + category_slug: integrations-api + title: Snowflake + slug: integrations-Snowflake + order: 130 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_snowflake.md \ No newline at end of file From 32e62f500f506a2a897db35ad2d7e913629a46f7 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 21:22:09 -0500 Subject: [PATCH 05/11] add CHANGELOG file --- integrations/snowflake/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 integrations/snowflake/CHANGELOG.md diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md new file mode 100644 index 000000000..c84a3b08e --- /dev/null +++ b/integrations/snowflake/CHANGELOG.md @@ -0,0 +1 @@ +## [integrations/snowflake-v0.0.0] - 2024-09-06 \ No newline at end of file From 1f8e7668fa0bbf3ba9ec1dd34c8a291dca6f9af0 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 21:26:22 -0500 Subject: [PATCH 06/11] update pyproject.toml --- integrations/snowflake/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index 63434d1bf..e05b1f498 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -142,5 +142,5 @@ show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] 
[[tool.mypy.overrides]] -module = ["haystack.*", "haystack_integrations.*", "pytest.*"] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "openai.*", "snowflake.*"] ignore_missing_imports = true \ No newline at end of file From fd7c6431397eb8625dcfc9dfd35faee0db74f96a Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 7 Sep 2024 21:27:01 -0500 Subject: [PATCH 07/11] lint file --- .../components/retrievers/snowflake/snowflake_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py index 28a014ecf..85e573300 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py @@ -96,7 +96,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
""" - return default_to_dict( # type: ignore + return default_to_dict( self, user=self.user, account=self.account, @@ -119,7 +119,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeRetriever": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) - return default_from_dict(cls, data) # type: ignore + return default_from_dict(cls, data) @staticmethod def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConnection, None]: From be30876ddd92749bec90fdd692a3d3253414cea6 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sun, 8 Sep 2024 15:15:37 -0500 Subject: [PATCH 08/11] add example and fix lint --- integrations/snowflake/README.md | 23 ++++ .../snowflake/example/text2sql_example.py | 120 ++++++++++++++++++ integrations/snowflake/pyproject.toml | 11 +- .../snowflake/snowflake_retriever.py | 2 +- .../tests/test_snowflake_retriever.py | 43 +++---- 5 files changed, 172 insertions(+), 27 deletions(-) create mode 100644 integrations/snowflake/example/text2sql_example.py diff --git a/integrations/snowflake/README.md b/integrations/snowflake/README.md index e69de29bb..30f0aee1a 100644 --- a/integrations/snowflake/README.md +++ b/integrations/snowflake/README.md @@ -0,0 +1,23 @@ +# snowflake-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [License](#license) + +## Installation + +```console +pip install snowflake-haystack +``` +## Examples +You can find a code example showing how to use the Retriever under the `example/` folder of this repo. + +## License + +`snowflake-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
\ No newline at end of file diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py new file mode 100644 index 000000000..8ff9d409e --- /dev/null +++ b/integrations/snowflake/example/text2sql_example.py @@ -0,0 +1,120 @@ +from dotenv import load_dotenv +from haystack import Pipeline +from haystack.components.builders import PromptBuilder +from haystack.components.converters import OutputAdapter +from haystack.components.generators import OpenAIGenerator +from haystack.utils import Secret + +from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever + +load_dotenv() + +sql_template = """ + You are a SQL expert working with Snowflake. + + Your task is to create a Snowflake SQL query for the given question. + + Refrain from explaining your answer. Your answer must be the SQL query + in plain text format without using Markdown. + + Here are some relevant tables, a description about it, and their + columns: + + Database name: DEMO_DB + Schema name: ADVENTURE_WORKS + Table names: + - ADDRESS: Employees Address Table + - EMPLOYEE: Employees directory + - SALESTERRITORY: Sales territory lookup table. + - SALESORDERHEADER: General sales order information. + + User's question: {{ question }} + + Generated SQL query: +""" + +sql_builder = PromptBuilder(template=sql_template) + +analyst_template = """ + You are an expert data analyst. + + Your role is to answer the user's question {{ question }} using the information + in the table. + + You will base your response solely on the information provided in the + table. + + Do not rely on your knowledge base; only the data that is in the table. + + Refrain from using the term "table" in your response, but instead, use + the word "data" + + If the table is blank say: + + "The specific answer can't be found in the database. Try rephrasing your + question." 
+ + Additionally, you will present the table in a tabular format and provide + the SQL query used to extract the relevant rows from the database in + Markdown. + + If the table is larger than 10 rows, display the most important rows up + to 10 rows. Your answer must be detailed and provide insights based on + the question and the available data. + + SQL query: + + {{ sql_query }} + + Table: + + {{ table }} + + Answer: +""" + +analyst_builder = PromptBuilder(template=analyst_template) + +# Model responsible for generating the SQL query +sql_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 1000}, +) + +# Model responsible for analyzing the table +analyst_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, +) + +snowflake = SnowflakeRetriever( + user="", + account="", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + warehouse="", +) + +adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) + +pipeline = Pipeline() + +pipeline.add_component(name="sql_builder", instance=sql_builder) +pipeline.add_component(name="sql_llm", instance=sql_llm) +pipeline.add_component(name="adapter", instance=adapter) +pipeline.add_component(name="snowflake", instance=snowflake) +pipeline.add_component(name="analyst_builder", instance=analyst_builder) +pipeline.add_component(name="analyst_llm", instance=analyst_llm) + + +pipeline.connect("sql_builder.prompt", "sql_llm.prompt") +pipeline.connect("sql_llm.replies", "adapter.replies") +pipeline.connect("adapter.output", "snowflake.query") +pipeline.connect("snowflake.table", "analyst_builder.table") +pipeline.connect("adapter.output", "analyst_builder.sql_query") +pipeline.connect("analyst_builder.prompt", "analyst_llm.prompt") + +question = "What are my top territories by number of orders and by sales value?" 
+ +response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index e05b1f498..68f9ec477 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -10,7 +10,8 @@ readme = "README.md" requires-python = ">=3.8" license = "Apache-2.0" keywords = [] -authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Mohamed Sriha", email = "mohamed.sriha@deepset.ai" }] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -23,9 +24,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", - "snowflake-connector-python>=3.10.1", - "tabulate>=0.9.0"] +dependencies = ["haystack-ai", "snowflake-connector-python>=3.10.1", "tabulate>=0.9.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme" @@ -114,6 +113,10 @@ ignore = [ "PLR0912", "PLR0913", "PLR0915", + # Ignore SQL injection + "S608", + # Unused method argument + "ARG002" ] unfixable = [ # Don't touch unused imports diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py index 85e573300..e874631de 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py @@ -32,7 +32,7 @@ class SnowflakeRetriever: executor = SnowflakeRetriever( user="", account="", - api_key=Secret.from_env_var(""), + 
api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), database="", db_schema="", warehouse="", diff --git a/integrations/snowflake/tests/test_snowflake_retriever.py b/integrations/snowflake/tests/test_snowflake_retriever.py index 98832af72..c3d748086 100644 --- a/integrations/snowflake/tests/test_snowflake_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_retriever.py @@ -74,7 +74,6 @@ def test_query_is_empty(self, snowflake_retriever: SnowflakeRetriever, caplog: L def test_exception( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: - mock_connect = mock_connect.return_value mock_connect._fetch_data.side_effect = Exception("Unknown error") @@ -112,7 +111,6 @@ def test_programing_error_during_connection( def test_execute_sql_query_programming_error( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: - mock_conn = MagicMock() mock_cursor = mock_conn.cursor.return_value @@ -153,7 +151,6 @@ def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever assert result == ["TABLE_A"] def test_extract_database_and_schema_from_query(self, snowflake_retriever: SnowflakeRetriever) -> None: - # when database and schema are next to table name assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ "DB.SCHEMA.TABLE_A" @@ -173,16 +170,23 @@ def test_extract_multiple_table_names(self, snowflake_retriever: SnowflakeRetrie assert result == ["TABLE_A", "TABLE_B"] or ["TABLE_B", "TABLE_A"] def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None: - - assert snowflake_retriever._extract_table_names( - query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a """ - + """FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" - ) == ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"] or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B" ""] + assert ( + 
snowflake_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a + FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"] + or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B"] + ) # No database - assert snowflake_retriever._extract_table_names( - query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a """ - + """FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" - ) == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] + assert ( + snowflake_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a + FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] + or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] + ) @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: @@ -250,7 +254,7 @@ def test_is_select_only( ] result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) - print(result) + assert not result assert "User does not have `Select` privilege on the table" in caplog.text @@ -311,7 +315,7 @@ def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetrie } result = snowflake_retriever.run(query=query) - print(result["table"]) + assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] @@ -344,7 +348,6 @@ def mock_chat_completion(self) -> Generator: def test_run_pipeline( self, mock_connect: MagicMock, mock_chat_completion: MagicMock, snowflake_retriever: SnowflakeRetriever ) -> None: - mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -484,9 +487,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: } 
@patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_has_select_privilege( - self, mock_logger: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture - ) -> None: + def test_has_select_privilege(self, mock_logger: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: # Define test cases test_cases = [ # Test case 1: Fully qualified table name in query @@ -532,21 +533,19 @@ def test_has_select_privilege( def test_user_does_not_exist( self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture ) -> None: - mock_conn = MagicMock() mock_connect.return_value = mock_conn mock_cursor = mock_conn.cursor.return_value mock_cursor.fetchall.return_value = [] - result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table") + result = snowflake_retriever._fetch_data(query="""SELECT * FROM test_table""") assert result.empty assert "User does not exist" in caplog.text @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_empty_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: - + def test_empty_query(self, snowflake_retriever: SnowflakeRetriever) -> None: result = snowflake_retriever._fetch_data(query="") assert result.empty From 8fef40a7aa24787c0a6db433fd63e8cbd9bb7117 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sun, 8 Sep 2024 15:31:38 -0500 Subject: [PATCH 09/11] update comments --- integrations/snowflake/example/text2sql_example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py index 8ff9d409e..3fe371d87 100644 --- a/integrations/snowflake/example/text2sql_example.py +++ b/integrations/snowflake/example/text2sql_example.py @@ -75,14 +75,14 @@ analyst_builder = 
PromptBuilder(template=analyst_template) -# Model responsible for generating the SQL query +# LLM responsible for generating the SQL query sql_llm = OpenAIGenerator( model="gpt-4o", api_key=Secret.from_env_var("OPENAI_API_KEY"), generation_kwargs={"temperature": 0.0, "max_tokens": 1000}, ) -# Model responsible for analyzing the table +# LLM responsible for analyzing the table analyst_llm = OpenAIGenerator( model="gpt-4o", api_key=Secret.from_env_var("OPENAI_API_KEY"), @@ -117,4 +117,4 @@ question = "What are my top territories by number of orders and by sales value?" -response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) +response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) \ No newline at end of file From 7a6c49b1b5420edaa63923de16a88a81372b82a8 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sun, 8 Sep 2024 15:35:27 -0500 Subject: [PATCH 10/11] add header and trailing line --- integrations/snowflake/example/text2sql_example.py | 2 +- .../components/retrievers/snowflake/snowflake_retriever.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py index 3fe371d87..8b47b8f6c 100644 --- a/integrations/snowflake/example/text2sql_example.py +++ b/integrations/snowflake/example/text2sql_example.py @@ -117,4 +117,4 @@ question = "What are my top territories by number of orders and by sales value?" 
-response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) \ No newline at end of file +response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py index e874631de..0aa2d5a48 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 import re from typing import Any, Dict, Final, Optional, Union From 89ac0c9d56e6f513bfcacb7c6472a91d8b020724 Mon Sep 17 00:00:00 2001 From: medsriha Date: Sat, 14 Sep 2024 20:24:34 -0500 Subject: [PATCH 11/11] update based on review --- integrations/snowflake/CHANGELOG.md | 2 +- .../snowflake/example/text2sql_example.py | 4 +- .../retrievers/snowflake/__init__.py | 4 +- ...riever.py => snowflake_table_retriever.py} | 47 +++-- ...r.py => test_snowflake_table_retriever.py} | 194 ++++++++++++------ 5 files changed, 157 insertions(+), 94 deletions(-) rename integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/{snowflake_retriever.py => snowflake_table_retriever.py} (89%) rename integrations/snowflake/tests/{test_snowflake_retriever.py => test_snowflake_table_retriever.py} (70%) diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md index c84a3b08e..0553a3f4b 100644 --- a/integrations/snowflake/CHANGELOG.md +++ b/integrations/snowflake/CHANGELOG.md @@ -1 +1 @@ -## [integrations/snowflake-v0.0.0] - 2024-09-06 \ No newline at end of file +## [integrations/snowflake-v0.0.1] - 2024-09-06 
\ No newline at end of file diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py index 8b47b8f6c..b85a4c677 100644 --- a/integrations/snowflake/example/text2sql_example.py +++ b/integrations/snowflake/example/text2sql_example.py @@ -5,7 +5,7 @@ from haystack.components.generators import OpenAIGenerator from haystack.utils import Secret -from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever load_dotenv() @@ -89,7 +89,7 @@ generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, ) -snowflake = SnowflakeRetriever( +snowflake = SnowflakeTableRetriever( user="", account="", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py index dd409ba06..294d3cce4 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .snowflake_retriever import SnowflakeRetriever +from .snowflake_table_retriever import SnowflakeTableRetriever -__all__ = ["SnowflakeRetriever"] +__all__ = ["SnowflakeTableRetriever"] diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py similarity index 89% rename from integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py rename to integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index 
0aa2d5a48..aa6f5ff4d 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -24,7 +24,7 @@ @component -class SnowflakeRetriever: +class SnowflakeTableRetriever: """ Connects to a Snowflake database to execute a SQL query. For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector). @@ -32,7 +32,7 @@ class SnowflakeRetriever: ### Usage example: ```python - executor = SnowflakeRetriever( + executor = SnowflakeTableRetriever( user="", account="", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), @@ -51,12 +51,12 @@ class SnowflakeRetriever: results = executor.run(query=query) - print(results["dataframe"].head(2)) + print(results["dataframe"].head(2)) # Pandas dataframe # Column 1 Column 2 # 0 Value1 Value2 # 1 Value1 Value2 - print(results["table"]) + print(results["table"]) # Markdown # | Column 1 | Column 2 | # |:----------|:----------| # | Value1 | Value2 | @@ -111,7 +111,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": """ Deserializes the component from a dictionary. @@ -140,14 +140,13 @@ def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConne @staticmethod def _extract_table_names(query: str) -> list: """ - Extract table names from a SQL query using regex. + Extract table names from an SQL query using regex. The extracted table names will be checked for privilege. :param query: SQL query to extract table names from. 
""" - # Regular expressions to match table names in various clauses - suffix = "\\s+([a-zA-Z0-9_.]+)" + suffix = "\\s+([a-zA-Z0-9_.]+)" # Regular expressions to match table names in various clauses patterns = [ "MERGE\\s+INTO", @@ -168,7 +167,7 @@ def _extract_table_names(query: str) -> list: # Find all matches in the query matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE) - # Flatten list of tuples and remove duplication + # Flatten the list of tuples and remove duplication matches = list(set(sum(matches, ()))) # Clean and return unique table names @@ -177,7 +176,7 @@ def _extract_table_names(query: str) -> list: @staticmethod def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: """ - Execute a SQL query and fetch the results. + Execute an SQL query and fetch the results. :param conn: An open connection to Snowflake. :param query: The query to execute. @@ -185,18 +184,21 @@ def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: cur = conn.cursor() try: cur.execute(query) - # set a limit to MAX_SYS_ROWS rows to avoid fetching too many rows - rows = cur.fetchmany(size=MAX_SYS_ROWS) - # Convert data to a dataframe - df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) + rows = cur.fetchmany(size=MAX_SYS_ROWS) # set a limit to avoid fetching too many rows + + df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) # Convert data to a dataframe return df - except ProgrammingError as e: - logger.warning( - "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", - error_msg=e.msg, - sfqid=e.sfqid, - ) - return pd.DataFrame() + except Exception as e: + if isinstance(e, ProgrammingError): + logger.warning( + "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", + error_msg=e.msg, + sfqid=e.sfqid, + ) + else: + logger.warning("An unexpected error occurred: {error_msg}", 
error_msg=e) + + return pd.DataFrame() @staticmethod def _has_select_privilege(privileges: list, table_name: str) -> bool: @@ -213,7 +215,6 @@ def _has_select_privilege(privileges: list, table_name: str) -> bool: string=privilege[1], flags=re.IGNORECASE, ): - logger.error("User does not have `Select` privilege on the table.") return False return True @@ -304,6 +305,8 @@ def _fetch_data( user=self.user, ): df = self._execute_sql_query(conn=conn, query=query) + else: + logger.error("User does not have `Select` privilege on the table.") except Exception as e: logger.error("An unexpected error has occurred: {error}", error=e) diff --git a/integrations/snowflake/tests/test_snowflake_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py similarity index 70% rename from integrations/snowflake/tests/test_snowflake_retriever.py rename to integrations/snowflake/tests/test_snowflake_table_retriever.py index c3d748086..547f7e1b1 100644 --- a/integrations/snowflake/tests/test_snowflake_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -12,19 +12,20 @@ from haystack import Pipeline from haystack.components.converters import OutputAdapter from haystack.components.generators import OpenAIGenerator +from haystack.components.builders import PromptBuilder from haystack.utils import Secret from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice from pytest import LogCaptureFixture from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError -from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever -class TestSnowflakeRetriever: +class TestSnowflakeTableRetriever: @pytest.fixture - def snowflake_retriever(self) -> SnowflakeRetriever: - return SnowflakeRetriever( + def snowflake_table_retriever(self) -> 
SnowflakeTableRetriever: + return SnowflakeTableRetriever( user="test_user", account="test_account", api_key=Secret.from_token("test-api-key"), @@ -34,12 +35,16 @@ def snowflake_retriever(self) -> SnowflakeRetriever: login_timeout=30, ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_snowflake_connector( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_connect.return_value = mock_conn - conn = snowflake_retriever._snowflake_connector( + conn = snowflake_table_retriever._snowflake_connector( connect_params={ "user": "test_user", "account": "test_account", @@ -62,61 +67,71 @@ def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: assert conn == mock_conn - def test_query_is_empty(self, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture) -> None: + def test_query_is_empty( + self, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: query = "" - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty assert "Provide a valid SQL query" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_exception( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: 
LogCaptureFixture ) -> None: mock_connect = mock_connect.return_value mock_connect._fetch_data.side_effect = Exception("Unknown error") query = 4 - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty assert "An unexpected error has occurred" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_forbidden_error_during_connection( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_connect.side_effect = ForbiddenError(msg="Forbidden error", errno=403) - result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table") + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") assert result.empty assert "000403: Forbidden error" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_programing_error_during_connection( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_connect.side_effect = ProgrammingError(msg="Programming error", errno=403) - result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table") + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") assert result.empty assert "000403: Programming error" in caplog.text - 
@patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_execute_sql_query_programming_error( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_cursor = mock_conn.cursor.return_value mock_cursor.execute.side_effect = ProgrammingError(msg="Simulated programming error", sfqid="ABC-123") - result = snowflake_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") + result = snowflake_table_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") assert result.empty @@ -125,17 +140,21 @@ def test_execute_sql_query_programming_error( "the query in Snowflake UI (ID: ABC-123)" in caplog.text ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_run_connection_error(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_connection_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234) query = "SELECT * FROM test_table" - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty - def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_single_table_name(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: queries = [ "SELECT * FROM 
table_a", "SELECT name, value FROM (SELECT name, value FROM table_a) AS subquery", @@ -147,31 +166,35 @@ def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever "DROP TABLE table_a", ] for query in queries: - result = snowflake_retriever._extract_table_names(query) + result = snowflake_table_retriever._extract_table_names(query) assert result == ["TABLE_A"] - def test_extract_database_and_schema_from_query(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_database_and_schema_from_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: # when database and schema are next to table name - assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ "DB.SCHEMA.TABLE_A" ] # No database - assert snowflake_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == ["SCHEMA.TABLE_A"] + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == [ + "SCHEMA.TABLE_A" + ] - def test_extract_multiple_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_multiple_table_names(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: queries = [ "MERGE INTO table_a USING table_b ON table_a.id = table_b.id WHEN MATCHED", "SELECT a.name, b.value FROM table_a AS a FULL OUTER JOIN table_b AS b ON a.id = b.id", "SELECT a.name, b.value FROM table_a AS a RIGHT JOIN table_b AS b ON a.id = b.id", ] for query in queries: - result = snowflake_retriever._extract_table_names(query) + result = snowflake_table_retriever._extract_table_names(query) # Due to using set when deduplicating assert result == ["TABLE_A", "TABLE_B"] or ["TABLE_B", "TABLE_A"] - def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None: + def 
test_extract_multiple_db_schema_from_table_names( + self, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: assert ( - snowflake_retriever._extract_table_names( + snowflake_table_retriever._extract_table_names( query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" ) @@ -180,7 +203,7 @@ def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: ) # No database assert ( - snowflake_retriever._extract_table_names( + snowflake_table_retriever._extract_table_names( query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" ) @@ -188,8 +211,12 @@ def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -203,15 +230,17 @@ def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: S query = "SELECT * FROM test_table" expected = pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}) - result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) assert result.equals(expected) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + 
"haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_is_select_only( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -234,7 +263,7 @@ def test_is_select_only( ] query = "select * from locations" - result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) assert result mock_cursor.fetchall.side_effect = [ @@ -253,13 +282,16 @@ def test_is_select_only( ], ] - result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) assert not result - assert "User does not have `Select` privilege on the table" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_column_after_from( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -273,13 +305,15 @@ def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: S query = "SELECT id, extract(year from date_col) as year FROM test_table" expected = pd.DataFrame(data={"id": [1233], "year": [1998]}) - result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) + result = 
snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) assert result.equals(expected) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -314,7 +348,7 @@ def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetrie "table": "| City | State |\n|:--------|:---------|\n| Chicago | Illinois |", } - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] @@ -327,7 +361,7 @@ def mock_chat_completion(self) -> Generator: with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: completion = ChatCompletion( id="foo", - model="gpt-4o", + model="gpt-4o-mini", object="chat.completion", choices=[ Choice( @@ -344,9 +378,14 @@ def mock_chat_completion(self) -> Generator: mock_chat_completion_create.return_value = completion yield mock_chat_completion_create - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_run_pipeline( - self, mock_connect: MagicMock, mock_chat_completion: MagicMock, snowflake_retriever: SnowflakeRetriever + self, + mock_connect: MagicMock, + mock_chat_completion: MagicMock, + snowflake_table_retriever: SnowflakeTableRetriever, ) -> 
None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -385,7 +424,7 @@ def test_run_pipeline( pipeline.add_component("llm", llm) pipeline.add_component("adapter", adapter) - pipeline.add_component("snowflake", snowflake_retriever) + pipeline.add_component("snowflake", snowflake_table_retriever) pipeline.connect(sender="llm.replies", receiver="adapter.replies") pipeline.connect(sender="adapter.output", receiver="snowflake.query") @@ -398,7 +437,8 @@ def test_run_pipeline( def test_from_dict(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") data = { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever" + ".SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -413,7 +453,7 @@ def test_from_dict(self, monkeypatch: MagicMock) -> None: "login_timeout": 3, }, } - component = SnowflakeRetriever.from_dict(data) + component = SnowflakeTableRetriever.from_dict(data) assert component.user == "test_user" assert component.account == "new_account" @@ -425,7 +465,7 @@ def test_from_dict(self, monkeypatch: MagicMock) -> None: def test_to_dict_default(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") - component = SnowflakeRetriever( + component = SnowflakeTableRetriever( user="test_user", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), account="test_account", @@ -438,7 +478,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -457,7 +497,7 @@ def 
test_to_dict_default(self, monkeypatch: MagicMock) -> None: def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") - component = SnowflakeRetriever( + component = SnowflakeTableRetriever( user="John", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), account="TGMD-EEREW", @@ -470,7 +510,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -486,8 +526,12 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: }, } - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_has_select_privilege(self, mock_logger: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_has_select_privilege( + self, mock_logger: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: # Define test cases test_cases = [ # Test case 1: Fully qualified table name in query @@ -523,15 +567,17 @@ def test_has_select_privilege(self, mock_logger: MagicMock, snowflake_retriever: ] for case in test_cases: - result = snowflake_retriever._has_select_privilege( + result = snowflake_table_retriever._has_select_privilege( privileges=case["privileges"], # type: ignore table_name=case["table_name"], # type: ignore ) assert result == case["expected_result"] # type: ignore - 
@patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_user_does_not_exist( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_connect.return_value = mock_conn @@ -539,13 +585,27 @@ def test_user_does_not_exist( mock_cursor = mock_conn.cursor.return_value mock_cursor.fetchall.return_value = [] - result = snowflake_retriever._fetch_data(query="""SELECT * FROM test_table""") + result = snowflake_table_retriever._fetch_data(query="""SELECT * FROM test_table""") assert result.empty assert "User does not exist" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_empty_query(self, snowflake_retriever: SnowflakeRetriever) -> None: - result = snowflake_retriever._fetch_data(query="") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + result = snowflake_table_retriever._fetch_data(query="") assert result.empty + + def test_serialization_deserialization_pipeline(self) -> None: + + pipeline = Pipeline() + pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) + pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}")) + pipeline.connect("snow.table", "prompt_builder.table") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline