Return common data structure in DBApi derived classes
The ADR for Airflow's DB API specifies it needs to return a named tuple SerializableRow or a list of them.
joffreybienvenu-infrabel authored Dec 22, 2023
1 parent 33ee0b9 commit 5fe5d31
Showing 22 changed files with 191 additions and 108 deletions.
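
Before the file-by-file diff, a minimal sketch of the pattern this change standardizes. The `MyDriverHook` below is hypothetical (not part of this commit) and assumes apache-airflow-providers-common-sql >= 1.9.1 is installed; it shows a derived hook overriding `_make_common_data_structure` so that `run()` always hands results back as tuples or namedtuples, as the ADR requires.

from __future__ import annotations

from collections import namedtuple
from typing import Any, Sequence

from airflow.providers.common.sql.hooks.sql import DbApiHook  # needs common.sql >= 1.9.1


class MyDriverHook(DbApiHook):
    """Hypothetical hook for a driver whose cursor yields dict-like rows instead of tuples."""

    conn_name_attr = "my_driver_conn_id"  # connection attribute expected by DbApiHook subclasses

    def _make_common_data_structure(self, result: Any | Sequence[Any]) -> tuple | list[tuple]:
        # Normalize the driver's dict-like rows into namedtuples, per the common.sql ADR.
        if isinstance(result, list):
            if not result:
                return []
            row_cls = namedtuple("Row", list(result[0].keys()))  # type: ignore[misc]
            return [row_cls(**row) for row in result]
        return namedtuple("Row", list(result.keys()))(**result)  # type: ignore[misc]
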
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -385,8 +385,8 @@ repos:
files: ^dev/breeze/src/airflow_breeze/utils/docker_command_utils\.py$|^scripts/ci/docker_compose/local\.yml$
pass_filenames: false
additional_dependencies: ['rich>=12.4.4']
- id: check-common-sql-dependency-make-serializable
name: Check dependency of SQL Providers with '_make_serializable'
- id: check-sql-dependency-common-data-structure
name: Check dependency of SQL Providers with common data structure
entry: ./scripts/ci/pre_commit/pre_commit_check_common_sql_dependency.py
language: python
files: ^airflow/providers/.*/hooks/.*\.py$
4 changes: 2 additions & 2 deletions STATIC_CODE_CHECKS.rst
@@ -170,8 +170,6 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-cncf-k8s-only-for-executors | Check cncf.kubernetes imports used for executors only | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-common-sql-dependency-make-serializable | Check dependency of SQL Providers with '_make_serializable' | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-core-deprecation-classes | Verify usage of Airflow deprecation classes in core | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
@@ -240,6 +238,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-setup-order | Check order of dependencies in setup.cfg and setup.py | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-sql-dependency-common-data-structure | Check dependency of SQL Providers with common data structure | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-start-date-not-used-in-defaults | start_date not to be defined in default_args in example_dags | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-system-tests-present | Check if system tests have required segments of code | |
49 changes: 32 additions & 17 deletions airflow/providers/common/sql/hooks/sql.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import contextlib
import warnings
from contextlib import closing
from datetime import datetime
from typing import (
@@ -24,6 +26,7 @@
Callable,
Generator,
Iterable,
List,
Mapping,
Protocol,
Sequence,
@@ -36,7 +39,7 @@
import sqlparse
from sqlalchemy import create_engine

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
@@ -122,10 +125,10 @@ class DbApiHook(BaseHook):
"""
Abstract base class for sql hooks.
When subclassing, maintainers can override the `_make_serializable` method:
When subclassing, maintainers can override the `_make_common_data_structure` method:
This method transforms the result of the handler method (typically `cursor.fetchall()`) into
JSON-serializable objects. Most of the time, the underlying SQL library already returns tuples from
its cursor, and the `_make_serializable` method can be ignored.
objects common across all Hooks derived from this class (tuples). Most of the time, the underlying SQL
library already returns tuples from its cursor, and the `_make_common_data_structure` method can be ignored.
:param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
if you change the schema parameter value in the constructor of the derived Hook, such change
@@ -308,7 +311,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
...

def run(
@@ -319,7 +322,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the sql parameter to get them to
@@ -395,7 +398,7 @@ def run(
self._run_command(cur, sql_statement, parameters)

if handler is not None:
result = self._make_serializable(handler(cur))
result = self._make_common_data_structure(handler(cur))
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
@@ -415,19 +418,31 @@
else:
return results

@staticmethod
def _make_serializable(result: Any) -> Any:
"""Ensure the data returned from an SQL command is JSON-serializable.
def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple | list[tuple]:
"""Ensure the data returned from an SQL command is a standard tuple or list[tuple].
This method is intended to be overridden by subclasses of the `DbApiHook`. Its purpose is to
transform the result of an SQL command (typically returned by cursor methods) into a
JSON-serializable format.
transform the result of an SQL command (typically returned by cursor methods) into a common
data structure (a tuple or list[tuple]) across all DBApiHook derived Hooks, as defined in the
ADR-0002 of the sql provider.
If this method is not overridden, the result data is returned as-is. If the output of the cursor
is already a common data structure, this method should be ignored.
"""
# Backwards-compatibility call for providers implementing the old `_make_serializable` method.
with contextlib.suppress(AttributeError):
result = self._make_serializable(result=result) # type: ignore[attr-defined]
warnings.warn(
"The `_make_serializable` method is deprecated and support will be removed in a future "
f"version of the common.sql provider. Please update the {self.__class__.__name__}'s provider "
"to a version based on common.sql >= 1.9.1.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

If this method is not overridden, the result data is returned as-is.
If the output of the cursor is already JSON-serializable, this method
should be ignored.
"""
return result
if isinstance(result, Sequence):
return cast(List[tuple], result)
return cast(tuple, result)

def _run_command(self, cur, sql_statement, parameters):
"""Run a statement using an already open cursor."""
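
To illustrate the backwards-compatibility shim added above, here is a standalone mimic (plain classes, no Airflow imports, and a stand-in DeprecationWarning in place of AirflowProviderDeprecationWarning): a hook that still defines the old `_make_serializable` keeps working, but the new code path emits a deprecation warning.

import contextlib
import warnings


class LegacyHook:
    """Hypothetical provider hook still shipping the pre-1.9.1 `_make_serializable` hook point."""

    def _make_serializable(self, result):
        # Whatever the old provider did, e.g. turning driver rows into plain lists.
        return [list(row) for row in result]

    def _make_common_data_structure(self, result):
        # Simplified copy of the shim introduced in DbApiHook above.
        with contextlib.suppress(AttributeError):
            result = self._make_serializable(result=result)
            warnings.warn(
                "The `_make_serializable` method is deprecated; move to common.sql >= 1.9.1.",
                DeprecationWarning,  # stand-in for AirflowProviderDeprecationWarning
                stacklevel=2,
            )
        return result


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(LegacyHook()._make_common_data_structure([(1, "a")]))  # [[1, 'a']]
    print(caught[0].category.__name__)  # DeprecationWarning
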
1 change: 1 addition & 0 deletions airflow/providers/common/sql/provider.yaml
@@ -24,6 +24,7 @@ description: |
suspended: false
source-date-epoch: 1701983370
versions:
- 1.9.1
- 1.9.0
- 1.8.1
- 1.8.0
65 changes: 52 additions & 13 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,19 +16,32 @@
# under the License.
from __future__ import annotations

import warnings
from collections import namedtuple
from contextlib import closing
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Mapping,
Sequence,
TypeVar,
cast,
overload,
)

from databricks import sql # type: ignore[attr-defined]
from databricks.sql.types import Row

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.types import Row

LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

@@ -52,6 +65,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
on every request
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param return_tuple: Return a ``namedtuple`` object instead of a ``databricks.sql.Row`` object. Defaults
to False. In a future release of the provider, this will become True by default. This parameter
ensures backward compatibility during the transition phase to common tuple objects for all hooks based
on DbApiHook. This flag will also be removed in a future release.
:param kwargs: Additional parameters internal to Databricks SQL Connector parameters
"""

@@ -68,6 +85,7 @@ def __init__(
catalog: str | None = None,
schema: str | None = None,
caller: str = "DatabricksSqlHook",
return_tuple: bool = False,
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
@@ -80,8 +98,18 @@ def __init__(
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self.return_tuple = return_tuple
self.additional_params = kwargs

if not self.return_tuple:
warnings.warn(
"""Returning a raw `databricks.sql.Row` object is deprecated. A namedtuple will be
returned instead in a future release of the databricks provider. Set `return_tuple=True` to
enable this behavior.""",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

def _get_extra_config(self) -> dict[str, Any | None]:
extra_params = copy(self.databricks_conn.extra_dejson)
for arg in ["http_path", "session_configuration", *self.extra_parameters]:
@@ -167,7 +195,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
...

def run(
@@ -178,7 +206,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""
Run a command or a list of commands.
@@ -223,7 +251,12 @@ def run(
with closing(conn.cursor()) as cur:
self._run_command(cur, sql_statement, parameters)
if handler is not None:
result = self._make_serializable(handler(cur))
raw_result = handler(cur)
if self.return_tuple:
result = self._make_common_data_structure(raw_result)
else:
# Returning the raw result is deprecated and does not comply with the current common.sql interface
result = raw_result # type: ignore[assignment]
if return_single_query_results(sql, return_last, split_statements):
results = [result]
self.descriptions = [cur.description]
@@ -241,14 +274,20 @@ def run(
else:
return results

@staticmethod
def _make_serializable(result):
"""Transform the databricks Row objects into JSON-serializable lists."""
def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
"""Transform the databricks Row objects into namedtuple."""
# The ignored lines below respect the namedtuple docstring, but mypy does not support dynamically
# instantiated namedtuples, and never will: https://github.com/python/mypy/issues/848
if isinstance(result, list):
return [list(row) for row in result]
elif isinstance(result, Row):
return list(result)
return result
rows: list[Row] = result
rows_fields = rows[0].__fields__
rows_object = namedtuple("Row", rows_fields) # type: ignore[misc]
return cast(List[tuple], [rows_object(*row) for row in rows])
else:
row: Row = result
row_fields = row.__fields__
row_object = namedtuple("Row", row_fields) # type: ignore[misc]
return cast(tuple, row_object(*row))

def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
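
A hedged usage sketch for the new flag (hypothetical connection id, warehouse path, and query; assumes a Databricks connection is configured in Airflow and the databricks provider is installed against common.sql >= 1.9.1):

from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(
    databricks_conn_id="databricks_default",  # hypothetical connection id
    http_path="/sql/1.0/warehouses/abc123",   # hypothetical SQL warehouse path
    return_tuple=True,                        # opt in to namedtuple results and silence the deprecation warning
)
rows = hook.run("SELECT 1 AS id, 'x' AS label", handler=lambda cur: cur.fetchall())
# With return_tuple=True each row is a namedtuple, so attribute and index access both work:
# rows[0].id == rows[0][0]
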
1 change: 1 addition & 0 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -113,6 +113,7 @@ def get_db_hook(self) -> DatabricksSqlHook:
"catalog": self.catalog,
"schema": self.schema,
"caller": "DatabricksSqlOperator",
"return_tuple": True,
**self.client_parameters,
**self.hook_params,
}
2 changes: 1 addition & 1 deletion airflow/providers/databricks/provider.yaml
@@ -59,7 +59,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.8.1
- apache-airflow-providers-common-sql>=1.9.1
- requests>=2.27,<3
# The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
# it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
6 changes: 3 additions & 3 deletions airflow/providers/exasol/hooks/exasol.py
@@ -183,7 +183,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
...

def run(
@@ -194,7 +194,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
@@ -232,7 +232,7 @@ def run(
with closing(conn.execute(sql_statement, parameters)) as exa_statement:
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
result = handler(exa_statement)
result = self._make_common_data_structure(handler(exa_statement))
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_columns = self.get_description(exa_statement)
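
The Exasol hook does not override `_make_common_data_structure`, so the inherited DbApiHook default shown earlier in sql.py applies. A standalone mimic of that default, assuming the handler already yields plain tuples (in which case the conversion is effectively a pass-through):

from typing import List, Sequence, cast


def default_make_common_data_structure(result):
    # Simplified copy of the inherited DbApiHook default shown above.
    if isinstance(result, Sequence):
        return cast(List[tuple], result)
    return cast(tuple, result)


print(default_make_common_data_structure([(1, "a"), (2, "b")]))  # [(1, 'a'), (2, 'b')]
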
2 changes: 1 addition & 1 deletion airflow/providers/exasol/provider.yaml
@@ -52,7 +52,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.3.1
- apache-airflow-providers-common-sql>=1.9.1
- pyexasol>=0.5.1
- pandas>=0.17.1

32 changes: 15 additions & 17 deletions airflow/providers/odbc/hooks/odbc.py
@@ -17,10 +17,10 @@
"""This module contains ODBC hook."""
from __future__ import annotations

from typing import Any, NamedTuple
from typing import Any, List, NamedTuple, Sequence, cast
from urllib.parse import quote_plus

import pyodbc
from pyodbc import Connection, Row, connect

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import merge_dicts
@@ -195,9 +195,9 @@ def connect_kwargs(self) -> dict:

return merged_connect_kwargs

def get_conn(self) -> pyodbc.Connection:
def get_conn(self) -> Connection:
"""Returns a pyodbc connection object."""
conn = pyodbc.connect(self.odbc_connection_string, **self.connect_kwargs)
conn = connect(self.odbc_connection_string, **self.connect_kwargs)
return conn

@property
@@ -228,17 +228,15 @@ def get_sqlalchemy_connection(
cnx = engine.connect(**(connect_kwargs or {}))
return cnx

@staticmethod
def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) -> list[NamedTuple] | None:
"""Transform the pyodbc.Row objects returned from an SQL command into JSON-serializable NamedTuple."""
def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
"""Transform the pyodbc.Row objects returned from an SQL command into typed NamedTuples."""
# The ignored lines below respect the NamedTuple docstring, but mypy does not support dynamically
# instantiated NamedTuples, and never will: https://github.com/python/mypy/issues/848
columns: list[tuple[str, type]] | None = None
if isinstance(result, list):
columns = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", columns) # type: ignore[misc]
return [row_object(*row) for row in result]
elif isinstance(result, pyodbc.Row):
columns = [col[:2] for col in result.cursor_description]
return NamedTuple("Row", columns)(*result) # type: ignore[misc, operator]
return result
# instantiated typed NamedTuples, and never will: https://github.com/python/mypy/issues/848
field_names: list[tuple[str, type]] | None = None
if isinstance(result, Sequence):
field_names = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", field_names) # type: ignore[misc]
return cast(List[tuple], [row_object(*row) for row in result])
else:
field_names = [col[:2] for col in result.cursor_description]
return cast(tuple, NamedTuple("Row", field_names)(*result)) # type: ignore[misc, operator]
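
For reference, the dynamic NamedTuple construction used above boils down to the following standalone sketch (hypothetical column metadata and rows standing in for pyodbc's `cursor_description` and `Row` objects):

from typing import List, NamedTuple, cast

# Stand-ins for what pyodbc reports: cursor_description entries are (name, type_code, ...).
cursor_description = [("id", int, None, 10, 10, 0, False), ("name", str, None, 255, 255, 0, True)]
raw_rows = [(1, "Alice"), (2, "Bob")]

field_names = [col[:2] for col in cursor_description]           # [("id", int), ("name", str)]
row_cls = NamedTuple("Row", field_names)  # type: ignore[misc]  # dynamically built, as in the hook
rows = cast(List[tuple], [row_cls(*row) for row in raw_rows])
print(rows[0].id, rows[0].name)  # 1 Alice
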
2 changes: 1 addition & 1 deletion airflow/providers/odbc/provider.yaml
@@ -45,7 +45,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.8.1
- apache-airflow-providers-common-sql>=1.9.1
- pyodbc

integrations: