From ec18b58faa5e49e2ca831fd6c9decf2139f633c6 Mon Sep 17 00:00:00 2001 From: vatsrahul1001 Date: Mon, 22 Jan 2024 22:06:11 +0530 Subject: [PATCH] deprecate snowflakeApiOperator --- .../snowflake/operators/snowflake.py | 194 ++---------------- .../snowflake/triggers/snowflake_trigger.py | 19 +- tests/snowflake/operators/test_snowflake.py | 152 +------------- 3 files changed, 28 insertions(+), 337 deletions(-) diff --git a/astronomer/providers/snowflake/operators/snowflake.py b/astronomer/providers/snowflake/operators/snowflake.py index e13665188..7a60282af 100644 --- a/astronomer/providers/snowflake/operators/snowflake.py +++ b/astronomer/providers/snowflake/operators/snowflake.py @@ -2,11 +2,10 @@ import logging import typing +import warnings from contextlib import closing -from datetime import timedelta from typing import Any, Callable, List -import requests from airflow.exceptions import AirflowException from snowflake.connector import SnowflakeConnection @@ -26,11 +25,7 @@ SnowflakeHookAsync, fetch_all_snowflake_handler, ) -from astronomer.providers.snowflake.hooks.snowflake_sql_api import ( - SnowflakeSqlApiHookAsync, -) from astronomer.providers.snowflake.triggers.snowflake_trigger import ( - SnowflakeSqlApiTrigger, SnowflakeTrigger, get_db_hook, ) @@ -226,181 +221,18 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] | class SnowflakeSqlApiOperatorAsync(SnowflakeOperator): """ - Implemented Async Snowflake SQL API Operator to support multiple SQL statements sequentially, - which is the behavior of the SnowflakeOperator, the Snowflake SQL API allows submitting - multiple SQL statements in a single request. In combination with aiohttp, make post request to submit SQL - statements for execution, poll to check the status of the execution of a statement. Fetch query results - concurrently. 
- This Operator currently uses key pair authentication, so you need tp provide private key raw content or - private key file path in the snowflake connection along with other details - - .. seealso:: - - `Snowflake SQL API key pair Authentication `_ - - Where can this operator fit in? - - To execute multiple SQL statements in a single request - - To execute the SQL statement asynchronously and to execute standard queries and most DDL and DML statements - - To develop custom applications and integrations that perform queries - - To create provision users and roles, create table, etc. - - The following commands are not supported: - - The PUT command (in Snowflake SQL) - - The GET command (in Snowflake SQL) - - The CALL command with stored procedures that return a table(stored procedures with the RETURNS TABLE clause). - - .. seealso:: - - - `Snowflake SQL API `_ - - `API Reference `_ - - `Limitation on snowflake SQL API `_ - - :param snowflake_conn_id: Reference to Snowflake connection id - :param sql: the sql code to be executed. (templated) - :param autocommit: if True, each command is automatically committed. - (default value: True) - :param parameters: (optional) the parameters to render the SQL query with. - :param warehouse: name of warehouse (will overwrite any warehouse - defined in the connection's extra JSON) - :param database: name of database (will overwrite database defined - in connection) - :param schema: name of schema (will overwrite schema defined in - connection) - :param role: name of role (will overwrite any role defined in - connection's extra JSON) - :param authenticator: authenticator for Snowflake. - 'snowflake' (default) to use the internal Snowflake authenticator - 'externalbrowser' to authenticate using your web browser and - Okta, ADFS or any other SAML 2.0-compliant identify provider - (IdP) that has been defined for your account - 'https://.okta.com' to authenticate - through native Okta. 
- :param session_parameters: You can set session-level parameters at - the time you connect to Snowflake - :param poll_interval: the interval in seconds to poll the query - :param statement_count: Number of SQL statement to be executed - :param token_life_time: lifetime of the JWT Token - :param token_renewal_delta: Renewal time of the JWT Token - :param bindings: (Optional) Values of bind variables in the SQL statement. - When executing the statement, Snowflake replaces placeholders (? and :name) in - the statement with these specified values. - """ # noqa - - LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime - RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes - - def __init__( - self, - *, - snowflake_conn_id: str = "snowflake_default", - warehouse: str | None = None, - database: str | None = None, - role: str | None = None, - schema: str | None = None, - authenticator: str | None = None, - session_parameters: dict[str, Any] | None = None, - poll_interval: int = 5, - statement_count: int = 0, - token_life_time: timedelta = LIFETIME, - token_renewal_delta: timedelta = RENEWAL_DELTA, - bindings: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - self.warehouse = warehouse - self.database = database - self.role = role - self.schema = schema - self.authenticator = authenticator - self.session_parameters = session_parameters - self.snowflake_conn_id = snowflake_conn_id - self.poll_interval = poll_interval - self.statement_count = statement_count - self.token_life_time = token_life_time - self.token_renewal_delta = token_renewal_delta - self.bindings = bindings - self.execute_async = False - if self.__class__.__base__.__name__ != "SnowflakeOperator": # type: ignore[union-attr] - # It's better to do str check of the parent class name because currently SnowflakeOperator - # is deprecated and in future OSS SnowflakeOperator may be removed - if any( - [warehouse, database, role, schema, authenticator, 
session_parameters] - ): # pragma: no cover - hook_params = kwargs.pop("hook_params", {}) # pragma: no cover - kwargs["hook_params"] = { - "warehouse": warehouse, - "database": database, - "role": role, - "schema": schema, - "authenticator": authenticator, - "session_parameters": session_parameters, - **hook_params, - } - super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover - else: - super().__init__(**kwargs) - - def execute(self, context: Context) -> None: - """ - Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids. - By deferring the SnowflakeSqlApiTrigger class passed along with query ids. - """ - self.log.info("Executing: %s", self.sql) - hook = SnowflakeSqlApiHookAsync( - snowflake_conn_id=self.snowflake_conn_id, - token_life_time=self.token_life_time, - token_renewal_delta=self.token_renewal_delta, - ) - hook.execute_query(self.sql, statement_count=self.statement_count, bindings=self.bindings) - self.query_ids = hook.query_ids - self.log.info("List of query ids %s", self.query_ids) - - if self.do_xcom_push: - context["ti"].xcom_push(key="query_ids", value=self.query_ids) - - succeeded_query_ids = [] - for query_id in self.query_ids: - self.log.info("Retrieving status for query id %s", query_id) - header, params, url = hook.get_request_url_header_params(query_id) - with requests.session() as session: - session.headers = header - with session.get(url, params=params) as resp: - event = hook.process_query_status_response(resp.json(), resp.status_code) - if resp.status_code == 202: - break - elif resp.status_code == 200: - succeeded_query_ids.append(query_id) - else: - raise AirflowException(f"{event['status']}: {event['message']}") - - if len(self.query_ids) == len(succeeded_query_ids): - self.log.info("%s completed successfully.", self.task_id) - return + This class is deprecated and will be removed in 2.0.0. 
+    Use :class:`~airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` instead.
+    """

-        self.defer(
-            timeout=self.execution_timeout,
-            trigger=SnowflakeSqlApiTrigger(
-                poll_interval=self.poll_interval,
-                query_ids=self.query_ids,
-                snowflake_conn_id=self.snowflake_conn_id,
-                token_life_time=self.token_life_time,
-                token_renewal_delta=self.token_renewal_delta,
+    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        warnings.warn(
+            (
+                "This class is deprecated. "
+                "Use `airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` "
+                "and set `deferrable` param to `True` instead."
             ),
-            method_name="execute_complete",
+            DeprecationWarning,
+            stacklevel=2,
         )
-
-    def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        if event:
-            if "status" in event and event["status"] == "error":
-                raise AirflowException(f"{event['status']}: {event['message']}")
-            elif "status" in event and event["status"] == "success":
-                hook = SnowflakeSqlApiHookAsync(snowflake_conn_id=self.snowflake_conn_id)
-                query_ids = typing.cast(List[str], event["statement_query_ids"])
-                hook.check_query_output(query_ids)
-                self.log.info("%s completed successfully.", self.task_id)
-            else:
-                self.log.info("%s completed successfully.", self.task_id)
+        super().__init__(*args, deferrable=True, **kwargs)
diff --git a/astronomer/providers/snowflake/triggers/snowflake_trigger.py b/astronomer/providers/snowflake/triggers/snowflake_trigger.py
index 376352a22..8de9e816a 100644
--- a/astronomer/providers/snowflake/triggers/snowflake_trigger.py
+++ b/astronomer/providers/snowflake/triggers/snowflake_trigger.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import warnings
 from datetime import timedelta
 from typing import Any, AsyncIterator

@@ 
-81,14 +82,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

 class SnowflakeSqlApiTrigger(BaseTrigger):
     """
-    SnowflakeSqlApi Trigger inherits from the BaseTrigger,it is fired as
-    deferred class with params to run the task in trigger worker and
-    fetch the status for the query ids passed
-
-    :param task_id: Reference to task id of the Dag
-    :param poll_interval: polling period in seconds to check for the status
-    :param query_ids: List of Query ids to run and poll for the status
-    :param snowflake_conn_id: Reference to Snowflake connection id
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead.
     """

     def __init__(
@@ -99,6 +94,14 @@ def __init__(
         token_life_time: timedelta,
         token_renewal_delta: timedelta,
     ):
+        warnings.warn(
+            (
+                "This class is deprecated and will be removed in 2.0.0. "
+                "Use :class:`~airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead"
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__()
         self.poll_interval = poll_interval
         self.query_ids = query_ids
diff --git a/tests/snowflake/operators/test_snowflake.py b/tests/snowflake/operators/test_snowflake.py
index 86c138097..3749936ef 100644
--- a/tests/snowflake/operators/test_snowflake.py
+++ b/tests/snowflake/operators/test_snowflake.py
@@ -14,7 +14,6 @@
     _check_queries_finish,
 )
 from astronomer.providers.snowflake.triggers.snowflake_trigger import (
-    SnowflakeSqlApiTrigger,
     SnowflakeTrigger,
 )
 from tests.utils.airflow_util import create_context
@@ -158,156 +157,13 @@ class TestSnowflakeSqlApiOperatorAsync:
         "mock_sql, statement_count, query_ids",
         [(SQL_MULTIPLE_STMTS, 4, (1, 2, 3, 4)), (SINGLE_STMT, 1, (5,))],
     )
-    @mock.patch("astronomer.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperatorAsync.defer")
-    @mock.patch("requests.session")
-    @mock.patch("astronomer.providers.snowflake.operators.snowflake.SnowflakeSqlApiHookAsync")
-    
def test_snowflake_sql_api_execute_operator_async_succeeded_before_defer( - self, - mock_hook, - mock_session, - mock_defer, - mock_sql, - statement_count, - query_ids, - ): - """ - Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired - when the SnowflakeSqlApiOperatorAsync is executed. - """ - mock_hook.return_value = MagicMock() - mock_hook.return_value.query_ids = query_ids - mock_hook.return_value.get_request_url_header_params.return_value = ("", "", "") - - mock_session.return_value.__enter__.return_value.get.return_value.__enter__.return_value.status_code = ( - 200 - ) - - operator = SnowflakeSqlApiOperatorAsync( + def test_init(self, mock_sql, statement_count, query_ids): + task = SnowflakeSqlApiOperatorAsync( task_id=TASK_ID, snowflake_conn_id=CONN_ID, sql=mock_sql, statement_count=statement_count, ) - operator.execute(create_context(operator)) - - assert not mock_defer.called - - @pytest.mark.parametrize( - "mock_sql, statement_count, query_ids", - [(SQL_MULTIPLE_STMTS, 4, (1, 2, 3, 4)), (SINGLE_STMT, 1, (5,))], - ) - @mock.patch("astronomer.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperatorAsync.defer") - @mock.patch("requests.session") - @mock.patch("astronomer.providers.snowflake.operators.snowflake.SnowflakeSqlApiHookAsync") - def test_snowflake_sql_api_execute_operator_async_failed_before_defer( - self, - mock_hook, - mock_session, - mock_defer, - mock_sql, - statement_count, - query_ids, - ): - """ - Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired - when the SnowflakeSqlApiOperatorAsync is executed. 
- """ - mock_hook.return_value = MagicMock() - mock_hook.return_value.query_ids = query_ids - mock_hook.return_value.get_request_url_header_params.return_value = ("", "", "") - - mock_session.return_value.__enter__.return_value.get.return_value.__enter__.return_value.status_code = ( - 422 - ) - - operator = SnowflakeSqlApiOperatorAsync( - task_id=TASK_ID, - snowflake_conn_id=CONN_ID, - sql=mock_sql, - statement_count=statement_count, - ) - - with pytest.raises(AirflowException): - operator.execute(create_context(operator)) - - assert not mock_defer.called - - @pytest.mark.parametrize( - "mock_sql, statement_count, query_ids", - [(SQL_MULTIPLE_STMTS, 4, (1, 2, 3, 4)), (SINGLE_STMT, 1, (5,))], - ) - @mock.patch("requests.session") - @mock.patch("astronomer.providers.snowflake.operators.snowflake.SnowflakeSqlApiHookAsync") - def test_snowflake_sql_api_execute_operator_async( - self, - mock_hook, - mock_session, - mock_sql, - statement_count, - query_ids, - ): - """ - Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be fired - when the SnowflakeSqlApiOperatorAsync is executed. 
- """ - mock_hook.return_value = MagicMock() - mock_hook.return_value.query_ids = query_ids - mock_hook.return_value.get_request_url_header_params.return_value = ("", "", "") - - mock_session.return_value.__enter__.return_value.get.return_value.__enter__.return_value.status_code = ( - 202 - ) - - operator = SnowflakeSqlApiOperatorAsync( - task_id=TASK_ID, - snowflake_conn_id=CONN_ID, - sql=mock_sql, - statement_count=statement_count, - ) - - with pytest.raises(TaskDeferred) as exc: - operator.execute(create_context(operator)) - - assert isinstance( - exc.value.trigger, SnowflakeSqlApiTrigger - ), "Trigger is not a SnowflakeSqlApiTrigger" - - def test_snowflake_sql_api_execute_complete_failure(self): - """Test SnowflakeSqlApiOperatorAsync raise AirflowException of error event""" - - operator = SnowflakeSqlApiOperatorAsync( - task_id=TASK_ID, - snowflake_conn_id=CONN_ID, - sql=SQL_MULTIPLE_STMTS, - statement_count=4, - ) - with pytest.raises(AirflowException): - operator.execute_complete( - context=None, - event={"status": "error", "message": "Test failure message", "type": "FAILED_WITH_ERROR"}, - ) - - @pytest.mark.parametrize( - "mock_event", - [ - None, - ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}), - ], - ) - @mock.patch( - "astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync.check_query_output" - ) - def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event): - """Tests execute_complete assert with successful message""" - - operator = SnowflakeSqlApiOperatorAsync( - task_id=TASK_ID, - snowflake_conn_id=CONN_ID, - sql=SQL_MULTIPLE_STMTS, - statement_count=4, - ) - - with mock.patch.object(operator.log, "info") as mock_log_info: - operator.execute_complete(context=None, event=mock_event) - mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) + assert isinstance(task, SnowflakeSqlApiOperatorAsync) + assert task.deferrable is True