diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..4089b88f6 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -73,6 +73,7 @@ def __init__( db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = None, + application_name: Optional[str] = None, ) -> None: """ :param user: User's login. @@ -82,6 +83,7 @@ def __init__( :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + :param application_name: Name of the application to use when connecting to Snowflake. """ self.user = user @@ -91,6 +93,7 @@ def __init__( self.db_schema = db_schema self.warehouse = warehouse self.login_timeout = login_timeout or 60 + self.application_name = application_name def to_dict(self) -> Dict[str, Any]: """ @@ -108,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]: db_schema=self.db_schema, warehouse=self.warehouse, login_timeout=self.login_timeout, + application_name=self.application_name, ) @classmethod @@ -285,6 +289,7 @@ def _fetch_data( "schema": self.db_schema, "warehouse": self.warehouse, "login_timeout": self.login_timeout, + **({"application": self.application_name} if self.application_name else {}), } ) if conn is None: diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index f5b8fee37..3e6e7d547 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -352,6 +352,64 @@ def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: Snowflake assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_with_application_name( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + snowflake_table_retriever.application_name = "test_application" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "locations", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.description = [mock_col1, mock_col2] + + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM locations" + + snowflake_table_retriever.run(query=query) + + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + application="test_application", + ) @pytest.fixture def mock_chat_completion(self) -> Generator: @@ -494,6 +552,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: "db_schema": "test_schema", "warehouse": "test_warehouse", "login_timeout": 30, + "application_name": None, }, } @@ -508,6 +567,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: db_schema="SMALL_TOWNS", warehouse="COMPUTE_WH", login_timeout=30, + application_name="test_application", ) data = component.to_dict() @@ -529,6 +589,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: "db_schema": "SMALL_TOWNS", "warehouse": "COMPUTE_WH", "login_timeout": 30, + "application_name": "test_application", }, } @@ -605,7 +666,6 @@ def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) - assert result.empty def test_serialization_deserialization_pipeline(self) -> None: - pipeline = Pipeline() pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}"))