Introduce notion of dialects in DbApiHook (apache#41327)
* refactor: Added unit test for handlers module in mssql

* refactor: Added unit test for Dialect class

* refactor: Reformatted unit test of Dialect class

* fix: Added missing import of TYPE_CHECKING

* refactor: Added dialects in provider schema and moved MsSqlDialect to Microsoft mssql provider

* refactor: Removed duplicate handlers and import them from handlers module

* refactor: Fixed import of TYPE_CHECKING in mssql hook

* refactor: Fixed some static checks and imports

* refactor: Dialect should be defined as an array in provider.yaml, not a single element

* refactor: Fixed default dialect name for common sql provider

* refactor: Fixed dialect name for Microsoft MSSQL provider

* refactor: Fixed module for dialect in python-modules of common sql provider

* refactor: Dialect module is not part of hooks

* refactor: Moved unit tests for default Dialect to common sql provider instead of Microsoft MSSQL provider

* refactor: Added unit test for MsSqlDialect

* refactor: Reformatted TestMsSqlDialect

* refactor: Implemented dialect resolution using the ProvidersManager in DbApiHook

* refactor: Updated comment in dialects property

* refactor: Added dialects list command

* refactor: Removed unused code from _import_hook method

* refactor: Reformatted _discover_provider_dialects method in ProvidersManager

* refactor: Removed unused imports from MsSqlHook

* refactor: Removed dialects from DbApiHook definition

* refactor: Reformatted _discover_provider_dialects

* refactor: Renamed module for TestMsSqlDialect

* refactor: test_generate_replace_sql in TestMsSqlHook should only be tested on Airflow 2.10 or higher

* refactor: Updated expected merge into statement

* refactor: Only run test_generate_replace_sql on TestMsSqlDialect when Airflow is higher than 2.10

* refactor: generate_insert_sql based on dialects should only be tested on Airflow 3.0 or higher

* refactor: Updated reason in skipped tests

* refactor: Removed locking in merge into

* refactor: Added kwargs to constructor of Dialect to make it future-proof in case additional arguments are needed later

* refactor: Removed row locking clause in generated replace sql statement and removed pyi file for mssql dialect

* refactor: Implemented PostgresDialect

* fix: Fixed constructor of Dialect

* refactor: Register PostgresDialect in providers.yaml and added unit test for PostgresDialect

* refactor: PostgresHook now uses the dialect to generate statements and get primary_keys

* refactor: Refactored DbApiHook

* refactor: Refactored the dialect_name mechanism in DbApiHook, override it in specialized Hooks

* refactor: Fixed some static checks

* refactor: Fixed dialect.pyi

* refactor: Refactored how dialects are resolved, if not found always fall back to default

* refactor: Reformatted dialect method in DbApiHook

* refactor: Changed message in raised exception of dialect method when not found

* refactor: Added missing get_records method in Dialect definition

* refactor: Fixed some static checks and mypy issues

* refactor: Raise ValueError if replace_index doesn't exist

* refactor: Increased version of apache-airflow-providers-common-sql to 1.17.1 for mssql and postgres

* refactor: Updated dialect.pyi

* refactor: Updated provider dependencies

* refactor: Incremented version of apache-airflow-providers-common-sql in test_get_install_requirements

* refactor: Reformatted get_records method

* refactor: Common sql provider must depend on latest Airflow version to be able to discover dialects through ProvidersManager

* refactor: Updated provider dependencies

* Revert "refactor: Updated provider dependencies"

This reverts commit 2b591f2.

* Revert "refactor: Common sql provider must depend on latest Airflow version to be able to discover dialects through ProvidersManager"

This reverts commit cb2d043.

* refactor: Added get_dialects method in DbApiHook which contains fallback code for Airflow 2.8.x so the provider can still be used with Airflow versions prior to 3.0.0

* fix: get_dialects isn't a property but a method

* refactor: Refactored get_dialects in DbApiHook and added unit tests

* refactor: Added unit tests for MsSqlHook related to dialects

* refactor: Added unit tests for PostgresHook related to dialects

* refactor: Fixed some static checks

* refactor: Removed get_dialects method as this wasn't backward compatible, avoid importing DialectInfo until min required airflow version is 3.0.0 or higher

* refactor: Re-added missing deprecated methods for backward compatibility

* refactor: Added resolve_dialects method in sql.pyi

* refactor: Reorganized imports in sql module

* refactor: Fixed definition of resolve_dialects in sql.pyi

* refactor: Fixed TestDialect

* refactor: Fixed DbApiHook tests and moved tests from DbApiHook to Odbc

* refactor: Ignore flake8 F811 error as those redefinitions are there for backward compatibility

* refactor: Move import of Dialect under TYPE_CHECKING block

* refactor: Fixed TestMsSqlDialect

* refactor: Fixed TestPostgresDialect

* refactor: Reformatted MsSqlHook

* refactor: Added docstring on placeholder property

* refactor: If no dialect is found for given dialect name, then return default Dialect

* refactor: Try ignoring flake8 F811 error as those redefinitions are there for backward compatibility

* refactor: Moved Dialect out of TYPE_CHECKING block

* fix: Fixed definition location of dialect in dialect.pyi

* fix: Fixed TestTeradataHook

* refactor: Marked test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_code as db test

* refactor: Removed handler definitions from sql.pyi

* Revert "refactor: Removed handler definitions from sql.pyi"

This reverts commit a93d73c.

* refactor: Removed white line

* refactor: Removed duplicate imports of handlers

* refactor: Fixed some static checks

* refactor: Changed logging level of generated sql statement to INFO in DbApiHook

* Revert "refactor: Changed logging level of generated sql statement to INFO in DbApiHook"

This reverts commit c30feaf.

* fix: Moved dialects to correct providers location

* fix: Deleted old providers location

* fix: Re-added missing dialects for mssql and postgres

* fix: Fixed 2 imports for providers tests

* refactor: Reorganized some imports

* refactor: Fixed dialect and sql types

* refactor: Fixed import of test_utils in test_dag_run

* refactor: Added white line in imports of test_dag_run

* refactor: Escape reserved words as column names

* refactor: Fixed initialisation of Dialects

* refactor: Renamed escape_reserved_word to escape_column_name

* refactor: Reformatted TestMsSqlDialect

* refactor: Fixed constructor definition of Dialect

* refactor: Fixed TestDbApiHook

* refactor: Removed get_reserved_words from dialect definition

* refactor: Added logging in get_reserved_words method

* refactor: Removed duplicate reserved_words property in DbApiHook

* refactor: Fixed invocation of reserved_words property and changed name of postgres dialect to postgresql like in sqlalchemy

* refactor: Removed override generate_insert_sql method in PostgresDialect as it doesn't do anything different than the existing one in Dialect

* refactor: Added unit tests for _generate_insert_sql methods on MsSqlHook and PostgresHook

* refactor: Reformatted test mssql and test postgres

* refactor: Fixed TestPostgresDialect

* refactor: Refactored get_reserved_words

* refactor: Added escape column name format so we can customize it if needed

* refactor: Suppress NoSuchModuleError exception when trying to load dialect from sqlalchemy to get reserved words

* refactor: Removed name from Dialect and added unit test for dialect name in JdbcHook

* refactor: Fixed parameters in get_column_names method of Dialect

* refactor: Added missing optional schema parameter to get_primary_keys method of MsSqlDialect

* refactor: Fixed TestDialect

* refactor: Fixed TestDbApiHook

* refactor: Fixed TestMsSqlDialect

* refactor: Reformatted test_generate_replace_sql

* refactor: Fixed dialect in MsSqlHook and PostgresHook

* refactor: Fixed TestPostgresDialect

* refactor: Mark TestMySqlHook as a db test

* refactor: Fixed test_generate_insert_sql_with_already_escaped_column_name in TestPostgresHook

* refactor: Reactivate postgres backend in TestPostgresHook

* refactor: Removed name param of constructor in Dialect definition

* refactor: Reformatted imports for TestMySqlHook

* refactor: Fixed import of get_provider_min_airflow_version in test sql

* refactor: Override default escape_column_name_format for MySqlHook

* refactor: Fixed tests in TestMySqlHook

* refactor: Refactored INSERT_SQL_STATEMENT constant in TestMySqlHook

* refactor: When using ODBC, we should also use the odbc connection when creating an sqlalchemy engine

* refactor: Added get_target_fields in Dialect which only returns insertable column_names and added core.dbapihook_resolve_target_fields configuration parameter to allow to specify if we want to resolve target_fields automatically or not

* refactor: By default the core.dbapihook_resolve_target_fields configuration parameter should be False so the original behaviour is respected

* refactor: Added logging statement for target_fields in Dialect

* refactor: Moved _resolve_target_fields as static field of DbApiHook and fixed TestMsSqlHook

* refactor: Added test for get_sqlalchemy_engine in OdbcHook

* refactor: Reformatted teardown method

* Revert "refactor: Added test for get_sqlalchemy_engine in OdbcHook"

This reverts commit 871e96b.

* refactor: Remove patched get_sql_alchemy method in OdbcHook, will fix this in dedicated PR

* refactor: Removed quotes from schema and table_name before invoking sqlalchemy inspector methods

* refactor: Removed check in test_sql for Airflow 2.8 plus as it is already at that min required version

* refactor: Fixed get_primary_keys method in PostgresDialect

* refactor: Reformatted get_primary_keys method of PostgresDialect

* refactor: extract_schema_from_table is now a public classmethod of Dialect

* fix: extract_schema_from_table is now a public classmethod of Dialect

* refactor: Reorganized imports

* refactor: Reorganized imports dialect and postgres

* refactor: Fixed test_dialect in TestProviderManager

* refactor: Removed operators section from provider.yaml in mssql and postgres

* refactor: Removed unused imports in postgres hook

* refactor: Added missing import for AirflowProviderDeprecationWarning

* refactor: Added rowlock option in merge into statement for MSSQL

* refactor: Updated expected replace statement for MSSQL

---------

Co-authored-by: David Blain <[email protected]>
Co-authored-by: David Blain <[email protected]>
3 people authored Dec 31, 2024
1 parent 6eab1f2 commit 87c55b5
Showing 31 changed files with 1,317 additions and 231 deletions.
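Several commits above describe the lookup behaviour ("if not found always fall back to default" and "changed name of postgres dialect to postgresql like in sqlalchemy"). A minimal standalone sketch of that fallback logic, with a hypothetical registry and function name that are not taken from the provider code:

```python
# Hypothetical sketch of name-based dialect lookup with a default
# fallback, as described in the commit log; the registry and the
# resolve_dialect helper are illustrative, not the provider's API.
class Dialect: ...
class MsSqlDialect(Dialect): ...
class PostgresDialect(Dialect): ...

DIALECTS: dict[str, type[Dialect]] = {
    "default": Dialect,
    "mssql": MsSqlDialect,
    "postgresql": PostgresDialect,  # named like the sqlalchemy dialect
}

def resolve_dialect(name: str) -> type[Dialect]:
    # Unknown dialect names fall back to the generic Dialect instead
    # of raising, keeping DbApiHook usable for any database.
    return DIALECTS.get(name, DIALECTS["default"])
```

In the actual change the registry is populated by the ProvidersManager from each provider's `provider.yaml`, but the fallback shape is the same.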
16 changes: 16 additions & 0 deletions providers/src/airflow/providers/common/sql/dialects/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
190 changes: 190 additions & 0 deletions providers/src/airflow/providers/common/sql/dialects/dialect.py
@@ -0,0 +1,190 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import re
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from methodtools import lru_cache

from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from sqlalchemy.engine import Inspector

T = TypeVar("T")


class Dialect(LoggingMixin):
    """Generic dialect implementation."""

    pattern = re.compile(r'"([a-zA-Z0-9_]+)"')

    def __init__(self, hook, **kwargs) -> None:
        super().__init__(**kwargs)

        from airflow.providers.common.sql.hooks.sql import DbApiHook

        if not isinstance(hook, DbApiHook):
            raise TypeError(f"hook must be an instance of {DbApiHook.__name__}")

        self.hook: DbApiHook = hook

    @classmethod
    def remove_quotes(cls, value: str | None) -> str | None:
        if value:
            return cls.pattern.sub(r"\1", value)
        return None

    @property
    def placeholder(self) -> str:
        return self.hook.placeholder

    @property
    def inspector(self) -> Inspector:
        return self.hook.inspector

    @property
    def _insert_statement_format(self) -> str:
        return self.hook._insert_statement_format  # type: ignore

    @property
    def _replace_statement_format(self) -> str:
        return self.hook._replace_statement_format  # type: ignore

    @property
    def _escape_column_name_format(self) -> str:
        return self.hook._escape_column_name_format  # type: ignore

    @classmethod
    def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]:
        parts = table.split(".")
        return tuple(parts[::-1]) if len(parts) == 2 else (table, None)

    @lru_cache(maxsize=None)
    def get_column_names(
        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
    ) -> list[str] | None:
        if schema is None:
            table, schema = self.extract_schema_from_table(table)
        column_names = [
            column["name"]
            for column in filter(
                predicate,
                self.inspector.get_columns(
                    table_name=self.remove_quotes(table),
                    schema=self.remove_quotes(schema) if schema else None,
                ),
            )
        ]
        self.log.debug("Column names for table '%s': %s", table, column_names)
        return column_names

    @lru_cache(maxsize=None)
    def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None:
        target_fields = self.get_column_names(
            table,
            schema,
            lambda column: not column.get("identity", False) and not column.get("autoincrement", False),
        )
        self.log.debug("Target fields for table '%s': %s", table, target_fields)
        return target_fields

    @lru_cache(maxsize=None)
    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None:
        if schema is None:
            table, schema = self.extract_schema_from_table(table)
        primary_keys = self.inspector.get_pk_constraint(
            table_name=self.remove_quotes(table),
            schema=self.remove_quotes(schema) if schema else None,
        ).get("constrained_columns", [])
        self.log.debug("Primary keys for table '%s': %s", table, primary_keys)
        return primary_keys

    def run(
        self,
        sql: str | Iterable[str],
        autocommit: bool = False,
        parameters: Iterable | Mapping[str, Any] | None = None,
        handler: Callable[[Any], T] | None = None,
        split_statements: bool = False,
        return_last: bool = True,
    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
        return self.hook.run(sql, autocommit, parameters, handler, split_statements, return_last)

    def get_records(
        self,
        sql: str | list[str],
        parameters: Iterable | Mapping[str, Any] | None = None,
    ) -> Any:
        return self.hook.get_records(sql=sql, parameters=parameters)

    @property
    def reserved_words(self) -> set[str]:
        return self.hook.reserved_words

    def escape_column_name(self, column_name: str) -> str:
        """
        Escape the column name if it's a reserved word.

        :param column_name: Name of the column
        :return: The escaped column name if needed
        """
        if (
            column_name != self._escape_column_name_format.format(column_name)
            and column_name.casefold() in self.reserved_words
        ):
            return self._escape_column_name_format.format(column_name)
        return column_name

    def _joined_placeholders(self, values) -> str:
        placeholders = [self.placeholder] * len(values)
        return ",".join(placeholders)

    def _joined_target_fields(self, target_fields) -> str:
        if target_fields:
            target_fields = ", ".join(map(self.escape_column_name, target_fields))
            return f"({target_fields})"
        return ""

    def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str:
        """
        Generate the INSERT SQL statement.

        :param table: Name of the target table
        :param values: The row to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :return: The generated INSERT SQL statement
        """
        return self._insert_statement_format.format(
            table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
        )

    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
        """
        Generate the REPLACE SQL statement.

        :param table: Name of the target table
        :param values: The row to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :return: The generated REPLACE SQL statement
        """
        return self._replace_statement_format.format(
            table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
        )
73 changes: 73 additions & 0 deletions providers/src/airflow/providers/common/sql/dialects/dialect.pyi
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This is automatically generated stub for the `common.sql` provider
#
# This file is generated automatically by the `update-common-sql-api stubs` pre-commit
# and the .pyi file represents part of the "public" API that the
# `common.sql` provider exposes to other providers.
#
# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor
# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking
# if only public API of the common.sql provider is used by all the other providers.
#
# You can read more in the README_API.md file
#
"""
Definition of the public interface for airflow.providers.common.sql.dialects.dialect
isort:skip_file
"""
from _typeshed import Incomplete as Incomplete
from airflow.utils.log.logging_mixin import LoggingMixin as LoggingMixin
from sqlalchemy.engine import Inspector as Inspector
from typing import Any, Callable, Iterable, Mapping, TypeVar

T = TypeVar("T")

class Dialect(LoggingMixin):
    hook: Incomplete
    def __init__(self, hook, **kwargs) -> None: ...
    @classmethod
    def remove_quotes(cls, value: str | None) -> str | None: ...
    @property
    def placeholder(self) -> str: ...
    @property
    def inspector(self) -> Inspector: ...
    @classmethod
    def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: ...
    def get_column_names(
        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = ...
    ) -> list[str] | None: ...
    def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None: ...
    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: ...
    def run(
        self,
        sql: str | Iterable[str],
        autocommit: bool = False,
        parameters: Iterable | Mapping[str, Any] | None = None,
        handler: Callable[[Any], T] | None = None,
        split_statements: bool = False,
        return_last: bool = True,
    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
    def get_records(
        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None
    ) -> Any: ...
    @property
    def reserved_words(self) -> set[str]: ...
    def escape_column_name(self, column_name: str) -> str: ...
    def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: ...
    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: ...