diff --git a/providers/src/airflow/providers/common/sql/dialects/__init__.py b/providers/src/airflow/providers/common/sql/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/src/airflow/providers/common/sql/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/common/sql/dialects/dialect.py b/providers/src/airflow/providers/common/sql/dialects/dialect.py new file mode 100644 index 0000000000000..184e6a5ce4e4c --- /dev/null +++ b/providers/src/airflow/providers/common/sql/dialects/dialect.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import re +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from methodtools import lru_cache + +from airflow.utils.log.logging_mixin import LoggingMixin + +if TYPE_CHECKING: + from sqlalchemy.engine import Inspector + +T = TypeVar("T") + + +class Dialect(LoggingMixin): + """Generic dialect implementation.""" + + pattern = re.compile(r'"([a-zA-Z0-9_]+)"') + + def __init__(self, hook, **kwargs) -> None: + super().__init__(**kwargs) + + from airflow.providers.common.sql.hooks.sql import DbApiHook + + if not isinstance(hook, DbApiHook): + raise TypeError(f"hook must be an instance of {DbApiHook.__class__.__name__}") + + self.hook: DbApiHook = hook + + @classmethod + def remove_quotes(cls, value: str | None) -> str | None: + if value: + return cls.pattern.sub(r"\1", value) + + @property + def placeholder(self) -> str: + return self.hook.placeholder + + @property + def inspector(self) -> Inspector: + return self.hook.inspector + + @property + def _insert_statement_format(self) -> str: + return self.hook._insert_statement_format # type: ignore + + @property + def _replace_statement_format(self) -> str: + return self.hook._replace_statement_format # type: ignore + + @property + def _escape_column_name_format(self) -> str: + return self.hook._escape_column_name_format # type: ignore + + @classmethod + def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: + parts = table.split(".") + return tuple(parts[::-1]) if len(parts) == 2 else (table, None) + + @lru_cache(maxsize=None) + def get_column_names( + self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True + ) -> list[str] | None: + if schema is None: + table, schema = self.extract_schema_from_table(table) + column_names = list( + column["name"] + for column in filter( + predicate, + self.inspector.get_columns( + table_name=self.remove_quotes(table), + schema=self.remove_quotes(schema) if schema else None, + ), + ) + ) + self.log.debug("Column names for table '%s': %s", table, column_names) + return column_names + + @lru_cache(maxsize=None) + def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None: + target_fields = self.get_column_names( + table, + schema, + lambda column: not column.get("identity", False) and not column.get("autoincrement", False), + ) + self.log.debug("Target fields for table '%s': %s", table, target_fields) + return target_fields + + @lru_cache(maxsize=None) + def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: + if schema is None: + table, schema = self.extract_schema_from_table(table) + primary_keys = self.inspector.get_pk_constraint( + table_name=self.remove_quotes(table), + schema=self.remove_quotes(schema) if schema else None, + ).get("constrained_columns", []) + self.log.debug("Primary keys for table '%s': %s", table, primary_keys) + return primary_keys + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping[str, Any] | None = None, + handler: Callable[[Any], T] | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: + return self.hook.run(sql, autocommit, parameters, handler, split_statements, return_last) + + def get_records( + self, + sql: str | list[str], + parameters: Iterable | Mapping[str, Any] | None = None, + ) -> Any: + return self.hook.get_records(sql=sql, parameters=parameters) + + @property + def reserved_words(self) -> set[str]: + return self.hook.reserved_words + + def escape_column_name(self, column_name: str) -> str: + """ + Escape the column name if it's a reserved word. + + :param column_name: Name of the column + :return: The escaped column name if needed + """ + if ( + column_name != self._escape_column_name_format.format(column_name) + and column_name.casefold() in self.reserved_words + ): + return self._escape_column_name_format.format(column_name) + return column_name + + def _joined_placeholders(self, values) -> str: + placeholders = [ + self.placeholder, + ] * len(values) + return ",".join(placeholders) + + def _joined_target_fields(self, target_fields) -> str: + if target_fields: + target_fields = ", ".join(map(self.escape_column_name, target_fields)) + return f"({target_fields})" + return "" + + def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: + """ + Generate the INSERT SQL statement. + + :param table: Name of the target table + :param values: The row to insert into the table + :param target_fields: The names of the columns to fill in the table + :return: The generated INSERT SQL statement + """ + return self._insert_statement_format.format( + table, self._joined_target_fields(target_fields), self._joined_placeholders(values) + ) + + def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: + """ + Generate the REPLACE SQL statement. + + :param table: Name of the target table + :param values: The row to insert into the table + :param target_fields: The names of the columns to fill in the table + :return: The generated REPLACE SQL statement + """ + return self._replace_statement_format.format( + table, self._joined_target_fields(target_fields), self._joined_placeholders(values) + ) diff --git a/providers/src/airflow/providers/common/sql/dialects/dialect.pyi b/providers/src/airflow/providers/common/sql/dialects/dialect.pyi new file mode 100644 index 0000000000000..423fab3ccd04d --- /dev/null +++ b/providers/src/airflow/providers/common/sql/dialects/dialect.pyi @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This is automatically generated stub for the `common.sql` provider +# +# This file is generated automatically by the `update-common-sql-api stubs` pre-commit +# and the .pyi file represents part of the "public" API that the +# `common.sql` provider exposes to other providers. +# +# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor +# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking +# if only public API of the common.sql provider is used by all the other providers. +# +# You can read more in the README_API.md file +# +""" +Definition of the public interface for airflow.providers.common.sql.dialects.dialect +isort:skip_file +""" +from _typeshed import Incomplete as Incomplete +from airflow.utils.log.logging_mixin import LoggingMixin as LoggingMixin +from sqlalchemy.engine import Inspector as Inspector +from typing import Any, Callable, Iterable, Mapping, TypeVar + +T = TypeVar("T") + +class Dialect(LoggingMixin): + hook: Incomplete + def __init__(self, hook, **kwargs) -> None: ... + @classmethod + def remove_quotes(cls, value: str | None) -> str | None: ... + @property + def placeholder(self) -> str: ... + @property + def inspector(self) -> Inspector: ... + @classmethod + def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: ... + def get_column_names( + self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = ... + ) -> list[str] | None: ... + def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None: ... + def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: ... + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping[str, Any] | None = None, + handler: Callable[[Any], T] | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ... + def get_records( + self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None + ) -> Any: ... + @property + def reserved_words(self) -> set[str]: ... + def escape_column_name(self, column_name: str) -> str: ... + def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: ... + def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: ... diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.py b/providers/src/airflow/providers/common/sql/hooks/sql.py index f4d107f0c5f3e..25d25eaec1786 100644 --- a/providers/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/src/airflow/providers/common/sql/hooks/sql.py @@ -18,8 +18,8 @@ import contextlib import warnings -from collections.abc import Generator, Iterable, Mapping, Sequence -from contextlib import closing, contextmanager +from collections.abc import Generator, Iterable, Mapping, MutableMapping, Sequence +from contextlib import closing, contextmanager, suppress from datetime import datetime from functools import cached_property from typing import ( @@ -34,15 +34,20 @@ from urllib.parse import urlparse import sqlparse +from methodtools import lru_cache from more_itertools import chunked from sqlalchemy import create_engine -from sqlalchemy.engine import Inspector +from sqlalchemy.engine import Inspector, make_url +from sqlalchemy.exc import ArgumentError, NoSuchModuleError +from airflow.configuration import conf from airflow.exceptions import ( AirflowException, AirflowOptionalProviderFeatureException, ) from airflow.hooks.base import BaseHook +from airflow.providers.common.sql.dialects.dialect import Dialect +from airflow.utils.module_loading import import_string if TYPE_CHECKING: from pandas import DataFrame @@ -83,6 +88,36 @@ def fetch_one_handler(cursor) -> list[tuple] | None: return handlers.fetch_one_handler(cursor) +def resolve_dialects() -> MutableMapping[str, MutableMapping]: + from airflow.providers_manager import ProvidersManager + + providers_manager = ProvidersManager() + + # TODO: this check can be removed once common sql provider depends on Airflow 3.0 or higher, + # we could then also use DialectInfo and won't need to convert it to a dict. + if hasattr(providers_manager, "dialects"): + return {key: dict(value._asdict()) for key, value in providers_manager.dialects.items()} + + # TODO: this can be removed once common sql provider depends on Airflow 3.0 or higher + return { + "default": dict( + name="default", + dialect_class_name="airflow.providers.common.sql.dialects.dialect.Dialect", + provider_name="apache-airflow-providers-common-sql", + ), + "mssql": dict( + name="mssql", + dialect_class_name="airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect", + provider_name="apache-airflow-providers-microsoft-mssql", + ), + "postgresql": dict( + name="postgresql", + dialect_class_name="airflow.providers.postgres.dialects.postgres.PostgresDialect", + provider_name="apache-airflow-providers-postgres", + ), + } + + class ConnectorProtocol(Protocol): """Database connection protocol.""" @@ -129,6 +164,8 @@ class DbApiHook(BaseHook): _test_connection_sql = "select 1" # Default SQL placeholder _placeholder: str = "%s" + _dialects: MutableMapping[str, MutableMapping] = resolve_dialects() + _resolve_target_fields = conf.getboolean("core", "dbapihook_resolve_target_fields", fallback=False) def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs): super().__init__() @@ -153,6 +190,7 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa self._replace_statement_format: str = kwargs.get( "replace_statement_format", "REPLACE INTO {} {} VALUES ({})" ) + self._escape_column_name_format: str = kwargs.get("escape_column_name_format", '"{}"') self._connection: Connection | None = kwargs.pop("connection", None) def get_conn_id(self) -> str: @@ -262,6 +300,57 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): def inspector(self) -> Inspector: return Inspector.from_engine(self.get_sqlalchemy_engine()) + @cached_property + def dialect_name(self) -> str: + try: + return make_url(self.get_uri()).get_dialect().name + except (ArgumentError, NoSuchModuleError): + config = self.connection_extra + sqlalchemy_scheme = config.get("sqlalchemy_scheme") + if sqlalchemy_scheme: + return sqlalchemy_scheme.split("+")[0] if "+" in sqlalchemy_scheme else sqlalchemy_scheme + return config.get("dialect", "default") + + @cached_property + def dialect(self) -> Dialect: + from airflow.utils.module_loading import import_string + + dialect_info = self._dialects.get(self.dialect_name) + + self.log.debug("dialect_info: %s", dialect_info) + + if dialect_info: + try: + return import_string(dialect_info["dialect_class_name"])(self) + except ImportError: + raise AirflowOptionalProviderFeatureException( + f"{dialect_info.dialect_class_name} not found, run: pip install " + f"'{dialect_info.provider_name}'." + ) + return Dialect(self) + + @property + def reserved_words(self) -> set[str]: + return self.get_reserved_words(self.dialect_name) + + @lru_cache(maxsize=None) + def get_reserved_words(self, dialect_name: str) -> set[str]: + result = set() + with suppress(ImportError, ModuleNotFoundError, NoSuchModuleError): + dialect_module = import_string(f"sqlalchemy.dialects.{dialect_name}.base") + + if hasattr(dialect_module, "RESERVED_WORDS"): + result = set(dialect_module.RESERVED_WORDS) + else: + dialect_module = import_string(f"sqlalchemy.dialects.{dialect_name}.reserved_words") + reserved_words_attr = f"RESERVED_WORDS_{dialect_name.upper()}" + + if hasattr(dialect_module, reserved_words_attr): + result = set(getattr(dialect_module, reserved_words_attr)) + + self.log.debug("reserved words for '%s': %s", dialect_name, result) + return result + def get_pandas_df( self, sql, @@ -543,7 +632,7 @@ def get_cursor(self) -> Any: """Return a cursor.""" return self.get_conn().cursor() - def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) -> str: + def _generate_insert_sql(self, table, values, target_fields=None, replace: bool = False, **kwargs) -> str: """ Generate the INSERT SQL statement. @@ -551,24 +640,19 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) :param table: Name of the target table :param values: The row to insert into the table - :param target_fields: The names of the columns to fill in the table + :param target_fields: The names of the columns to fill in the table. If no target fields are + specified, they will be determined dynamically from the table's metadata. :param replace: Whether to replace/upsert instead of insert :return: The generated INSERT or REPLACE/UPSERT SQL statement """ - placeholders = [ - self.placeholder, - ] * len(values) - - if target_fields: - target_fields = ", ".join(target_fields) - target_fields = f"({target_fields})" - else: - target_fields = "" + if not target_fields and self._resolve_target_fields: + with suppress(Exception): + target_fields = self.dialect.get_target_fields(table) - if not replace: - return self._insert_statement_format.format(table, target_fields, ",".join(placeholders)) + if replace: + return self.dialect.generate_replace_sql(table, values, target_fields, **kwargs) - return self._replace_statement_format.format(table, target_fields, ",".join(placeholders)) + return self.dialect.generate_insert_sql(table, values, target_fields, **kwargs) @contextmanager def _create_autocommit_connection(self, autocommit: bool = False): diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.pyi b/providers/src/airflow/providers/common/sql/hooks/sql.pyi index ed93958401ed4..afa9754a8b672 100644 --- a/providers/src/airflow/providers/common/sql/hooks/sql.pyi +++ b/providers/src/airflow/providers/common/sql/hooks/sql.pyi @@ -34,19 +34,33 @@ isort:skip_file from _typeshed import Incomplete as Incomplete from airflow.hooks.base import BaseHook as BaseHook from airflow.models import Connection as Connection +from airflow.providers.common.sql.dialects.dialect import Dialect as Dialect from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo from functools import cached_property as cached_property from pandas import DataFrame as DataFrame from sqlalchemy.engine import Inspector as Inspector, URL as URL -from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload +from typing import ( + Any, + Callable, + Generator, + Iterable, + Mapping, + MutableMapping, + Protocol, + Sequence, + TypeVar, + overload, +) T = TypeVar("T") SQL_PLACEHOLDERS: Incomplete +WARNING_MESSAGE: str def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): ... def fetch_all_handler(cursor) -> list[tuple] | None: ... def fetch_one_handler(cursor) -> list[tuple] | None: ... +def resolve_dialects() -> MutableMapping[str, MutableMapping]: ... class ConnectorProtocol(Protocol): def connect(self, host: str, port: int, username: str, schema: str) -> Any: ... @@ -79,6 +93,13 @@ class DbApiHook(BaseHook): def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ... @property def inspector(self) -> Inspector: ... + @cached_property + def dialect_name(self) -> str: ... + @cached_property + def dialect(self) -> Dialect: ... + @property + def reserved_words(self) -> set[str]: ... + def get_reserved_words(self, dialect_name: str) -> set[str]: ... def get_pandas_df( self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs ) -> DataFrame: ... diff --git a/providers/src/airflow/providers/common/sql/provider.yaml b/providers/src/airflow/providers/common/sql/provider.yaml index 32bfbe2d493e6..530cc35188265 100644 --- a/providers/src/airflow/providers/common/sql/provider.yaml +++ b/providers/src/airflow/providers/common/sql/provider.yaml @@ -93,6 +93,10 @@ operators: python-modules: - airflow.providers.common.sql.operators.sql +dialects: + - dialect-type: default + dialect-class-name: airflow.providers.common.sql.dialects.dialect.Dialect + hooks: - integration-name: Common SQL python-modules: diff --git a/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py b/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py b/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py new file mode 100644 index 0000000000000..fc2110a762d64 --- /dev/null +++ b/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from methodtools import lru_cache + +from airflow.providers.common.sql.dialects.dialect import Dialect +from airflow.providers.common.sql.hooks.handlers import fetch_all_handler + + +class MsSqlDialect(Dialect): + """Microsoft SQL Server dialect implementation.""" + + @lru_cache(maxsize=None) + def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: + primary_keys = self.run( + f""" + SELECT c.name + FROM sys.columns c + WHERE c.object_id = OBJECT_ID('{table}') + AND EXISTS (SELECT 1 FROM sys.index_columns ic + INNER JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id + WHERE i.is_primary_key = 1 + AND ic.object_id = c.object_id + AND ic.column_id = c.column_id); + """, + handler=fetch_all_handler, + ) + primary_keys = [pk[0] for pk in primary_keys] if primary_keys else [] # type: ignore + self.log.debug("Primary keys for table '%s': %s", table, primary_keys) + return primary_keys # type: ignore + + def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: + primary_keys = self.get_primary_keys(table) + columns = [ + self.escape_column_name(target_field) + for target_field in target_fields + if target_field in set(target_fields).difference(set(primary_keys)) + ] + + self.log.debug("primary_keys: %s", primary_keys) + self.log.debug("columns: %s", columns) + + return f"""MERGE INTO {table} WITH (ROWLOCK) AS target + USING (SELECT {', '.join(map(lambda column: f'{self.placeholder} AS {column}', target_fields))}) AS source + ON {' AND '.join(map(lambda column: f'target.{self.escape_column_name(column)} = source.{column}', primary_keys))} + WHEN MATCHED THEN + UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))} + WHEN NOT MATCHED THEN + INSERT ({', '.join(target_fields)}) VALUES ({', '.join(map(lambda column: f'source.{self.escape_column_name(column)}', target_fields))});""" diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py index a367250ed33c4..089f1ccfb7d3f 100644 --- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -19,13 +19,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import pymssql -from methodtools import lru_cache from pymssql import Connection as PymssqlConnection -from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect + +if TYPE_CHECKING: + from airflow.providers.common.sql.dialects.dialect import Dialect class MsSqlHook(DbApiHook): @@ -63,6 +66,14 @@ def sqlalchemy_scheme(self) -> str: raise RuntimeError("sqlalchemy_scheme in connection extra should not contain : or / characters") return self._sqlalchemy_scheme or extra_scheme or self.DEFAULT_SQLALCHEMY_SCHEME + @property + def dialect_name(self) -> str: + return "mssql" + + @property + def dialect(self) -> Dialect: + return MsSqlDialect(self) + def get_uri(self) -> str: from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit @@ -84,56 +95,6 @@ def get_sqlalchemy_connection( engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs) return engine.connect(**(connect_kwargs or {})) - @lru_cache(maxsize=None) - def get_primary_keys(self, table: str) -> list[str]: - primary_keys = self.run( - f""" - SELECT c.name - FROM sys.columns c - WHERE c.object_id = OBJECT_ID('{table}') - AND EXISTS (SELECT 1 FROM sys.index_columns ic - INNER JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id - WHERE i.is_primary_key = 1 - AND ic.object_id = c.object_id - AND ic.column_id = c.column_id); - """, - handler=fetch_all_handler, - ) - return [pk[0] for pk in primary_keys] # type: ignore - - def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) -> str: - """ - Generate the INSERT SQL statement. - - The MERGE INTO variant is specific to MSSQL syntax - - :param table: Name of the target table - :param values: The row to insert into the table - :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace/merge into instead of insert - :return: The generated INSERT or MERGE INTO SQL statement - """ - if not replace: - return super()._generate_insert_sql(table, values, target_fields, replace, **kwargs) # type: ignore - - primary_keys = self.get_primary_keys(table) - columns = [ - target_field - for target_field in target_fields - if target_field in set(target_fields).difference(set(primary_keys)) - ] - - self.log.debug("primary_keys: %s", primary_keys) - self.log.info("columns: %s", columns) - - return f"""MERGE INTO {table} AS target - USING (SELECT {', '.join(map(lambda column: f'{self.placeholder} AS {column}', target_fields))}) AS source - ON {' AND '.join(map(lambda column: f'target.{column} = source.{column}', primary_keys))} - WHEN MATCHED THEN - UPDATE SET {', '.join(map(lambda column: f'target.{column} = source.{column}', columns))} - WHEN NOT MATCHED THEN - INSERT ({', '.join(target_fields)}) VALUES ({', '.join(map(lambda column: f'source.{column}', target_fields))});""" - def get_conn(self) -> PymssqlConnection: """Return ``pymssql`` connection object.""" conn = self.connection diff --git a/providers/src/airflow/providers/microsoft/mssql/provider.yaml b/providers/src/airflow/providers/microsoft/mssql/provider.yaml index b5d0e63479851..d29ee70d8a517 100644 --- a/providers/src/airflow/providers/microsoft/mssql/provider.yaml +++ b/providers/src/airflow/providers/microsoft/mssql/provider.yaml @@ -72,6 +72,10 @@ integrations: - /docs/apache-airflow-providers-microsoft-mssql/operators.rst tags: [software] +dialects: + - dialect-type: mssql + dialect-class-name: airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect + hooks: - integration-name: Microsoft SQL Server (MSSQL) python-modules: diff --git a/providers/src/airflow/providers/mysql/hooks/mysql.py b/providers/src/airflow/providers/mysql/hooks/mysql.py index 5ed8a62d75f23..48185c1cf5516 100644 --- a/providers/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/src/airflow/providers/mysql/hooks/mysql.py @@ -82,6 +82,7 @@ def __init__(self, *args, **kwargs) -> None: self.schema = kwargs.pop("schema", None) self.local_infile = kwargs.pop("local_infile", False) self.init_command = kwargs.pop("init_command", None) + self._escape_column_name_format: str = kwargs.get("escape_column_name_format", "`{}`") def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None: """ diff --git a/providers/src/airflow/providers/postgres/dialects/__init__.py b/providers/src/airflow/providers/postgres/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/src/airflow/providers/postgres/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/postgres/dialects/postgres.py b/providers/src/airflow/providers/postgres/dialects/postgres.py new file mode 100644 index 0000000000000..5db4cca18f8e5 --- /dev/null +++ b/providers/src/airflow/providers/postgres/dialects/postgres.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from methodtools import lru_cache + +from airflow.providers.common.sql.dialects.dialect import Dialect + + +class PostgresDialect(Dialect): + """Postgres dialect implementation.""" + + @property + def name(self) -> str: + return "postgresql" + + @lru_cache(maxsize=None) + def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: + """ + Get the table's primary key. + + :param table: Name of the target table + :param schema: Name of the target schema, public by default + :return: Primary key columns list + """ + if schema is None: + table, schema = self.extract_schema_from_table(table) + sql = """ + select kcu.column_name + from information_schema.table_constraints tco + join information_schema.key_column_usage kcu + on kcu.constraint_name = tco.constraint_name + and kcu.constraint_schema = tco.constraint_schema + and kcu.constraint_name = tco.constraint_name + where tco.constraint_type = 'PRIMARY KEY' + and kcu.table_schema = %s + and kcu.table_name = %s + """ + pk_columns = [ + row[0] for row in self.get_records(sql, (self.remove_quotes(schema), self.remove_quotes(table))) + ] + return pk_columns or None + + def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: + """ + Generate the REPLACE SQL statement. + + :param table: Name of the target table + :param values: The row to insert into the table + :param target_fields: The names of the columns to fill in the table + :param replace: Whether to replace instead of insert + :param replace_index: the column or list of column names to act as + index for the ON CONFLICT clause + :return: The generated INSERT or REPLACE SQL statement + """ + if not target_fields: + raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names") + + replace_index = kwargs.get("replace_index") or self.get_primary_keys(table) + + if not replace_index: + raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index") + + if isinstance(replace_index, str): + replace_index = [replace_index] + + sql = self.generate_insert_sql(table, values, target_fields, **kwargs) + on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_column_name, replace_index))})" + replace_target = [self.escape_column_name(f) for f in target_fields if f not in replace_index] + + if replace_target: + replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target) + sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}" + else: + sql += f"{on_conflict_str} DO NOTHING" + + return sql diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py b/providers/src/airflow/providers/postgres/hooks/postgres.py index 9b657c14416e4..e760ebaab45c0 100644 --- a/providers/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/src/airflow/providers/postgres/hooks/postgres.py @@ -18,7 +18,6 @@ from __future__ import annotations import os -from collections.abc import Iterable from contextlib import closing from copy import deepcopy from typing import TYPE_CHECKING, Any, Union @@ -31,11 +30,13 @@ from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.postgres.dialects.postgres import PostgresDialect if TYPE_CHECKING: from psycopg2.extensions import connection from airflow.models.connection import Connection + from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.openlineage.sqlparser import DatabaseInfo CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor] @@ -123,6 +124,14 @@ def sqlalchemy_url(self) -> URL: query=query, ) + @property + def dialect_name(self) -> str: + return "postgresql" + + @property + def dialect(self) -> Dialect: + return PostgresDialect(self) + def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() cursor_types = { @@ -286,67 +295,7 @@ def get_table_primary_key(self, table: str, schema: str | None = "public") -> li :param schema: Name of the target schema, public by default :return: Primary key columns list """ - sql = """ - select kcu.column_name - from information_schema.table_constraints tco - join information_schema.key_column_usage kcu - on kcu.constraint_name = tco.constraint_name - and kcu.constraint_schema = tco.constraint_schema - and kcu.constraint_name = tco.constraint_name - where tco.constraint_type = 'PRIMARY KEY' - and kcu.table_schema = %s - and kcu.table_name = %s - """ - pk_columns = [row[0] for row in self.get_records(sql, (schema, table))] - return pk_columns or None - - def _generate_insert_sql( - self, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs - ) -> str: - """ - Generate the INSERT SQL statement. - - The REPLACE variant is specific to the PostgreSQL syntax. - - :param table: Name of the target table - :param values: The row to insert into the table - :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert - :param replace_index: the column or list of column names to act as - index for the ON CONFLICT clause - :return: The generated INSERT or REPLACE SQL statement - """ - placeholders = [ - self.placeholder, - ] * len(values) - replace_index = kwargs.get("replace_index") - - if target_fields: - target_fields_fragment = ", ".join(target_fields) - target_fields_fragment = f"({target_fields_fragment})" - else: - target_fields_fragment = "" - - sql = f"INSERT INTO {table} {target_fields_fragment} VALUES ({','.join(placeholders)})" - - if replace: - if not target_fields: - raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names") - if not replace_index: - raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index") - if isinstance(replace_index, str): - replace_index = [replace_index] - - on_conflict_str = f" ON CONFLICT ({', '.join(replace_index)})" - replace_target = [f for f in target_fields if f not in replace_index] - - if replace_target: - replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target) - sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}" - else: - sql += f"{on_conflict_str} DO NOTHING" - - return sql + return self.dialect.get_primary_keys(table=table, schema=schema) def get_openlineage_database_info(self, connection) -> DatabaseInfo: """Return Postgres/Redshift specific information for OpenLineage.""" diff --git a/providers/src/airflow/providers/postgres/provider.yaml b/providers/src/airflow/providers/postgres/provider.yaml index 13425797e9c54..3b2b09411940d 100644 --- a/providers/src/airflow/providers/postgres/provider.yaml +++ b/providers/src/airflow/providers/postgres/provider.yaml @@ -86,6 +86,10 @@ integrations: logo: /integration-logos/postgres/Postgres.png tags: [software] +dialects: + - dialect-type: postgresql + dialect-class-name: airflow.providers.postgres.dialects.postgres.PostgresDialect + hooks: - integration-name: PostgreSQL python-modules: diff --git a/providers/tests/common/sql/dialects/__init__.py b/providers/tests/common/sql/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/common/sql/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/tests/common/sql/dialects/test_dialect.py b/providers/tests/common/sql/dialects/test_dialect.py new file mode 100644 index 0000000000000..1021b1c617caa --- /dev/null +++ b/providers/tests/common/sql/dialects/test_dialect.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +from sqlalchemy.engine import Inspector + +from airflow.providers.common.sql.dialects.dialect import Dialect +from airflow.providers.common.sql.hooks.sql import DbApiHook + + +class TestDialect: + def setup_method(self): + inspector = MagicMock(spc=Inspector) + inspector.get_columns.side_effect = lambda table_name, schema: [ + {"name": "id", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ] + inspector.get_pk_constraint.side_effect = lambda table_name, schema: {"constrained_columns": ["id"]} + self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook) + + def test_remove_quotes(self): + assert not Dialect.remove_quotes(None) + assert Dialect.remove_quotes("table") == "table" + assert Dialect.remove_quotes('"table"') == "table" + + def test_placeholder(self): + assert Dialect(self.test_db_hook).placeholder == "?" + + def test_extract_schema_from_table(self): + assert Dialect.extract_schema_from_table("schema.table") == ("table", "schema") + + def test_get_column_names(self): + assert Dialect(self.test_db_hook).get_column_names("table", "schema") == [ + "id", + "name", + "firstname", + "age", + ] + + def test_get_target_fields(self): + assert Dialect(self.test_db_hook).get_target_fields("table", "schema") == [ + "name", + "firstname", + "age", + ] + + def test_get_primary_keys(self): + assert Dialect(self.test_db_hook).get_primary_keys("table", "schema") == ["id"] diff --git a/providers/tests/common/sql/hooks/test_dbapi.py b/providers/tests/common/sql/hooks/test_dbapi.py index 1f3f39aa451ab..57a74987dfe7f 100644 --- a/providers/tests/common/sql/hooks/test_dbapi.py +++ b/providers/tests/common/sql/hooks/test_dbapi.py @@ -28,6 +28,7 @@ from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.hooks.base import BaseHook from airflow.models import Connection +from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, fetch_one_handler @@ -62,6 +63,10 @@ def get_connection(cls, conn_id: str) -> Connection: def get_conn(self): return conn + @property + def dialect(self): + return Dialect(self) + def get_db_log_messages(self, conn) -> None: return conn.get_messages() diff --git a/providers/tests/common/sql/hooks/test_sql.py b/providers/tests/common/sql/hooks/test_sql.py index 756663ca39c9a..cb7696b370156 100644 --- a/providers/tests/common/sql/hooks/test_sql.py +++ b/providers/tests/common/sql/hooks/test_sql.py @@ -18,6 +18,7 @@ # from __future__ import annotations +import inspect import logging import logging.config from unittest.mock import MagicMock @@ -25,11 +26,14 @@ import pytest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import Connection -from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler +from airflow.providers.common.sql.dialects.dialect import Dialect +from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, resolve_dialects from airflow.utils.session import provide_session from providers.tests.common.sql.test_utils import mock_hook +from tests_common.test_utils.providers import get_provider_min_airflow_version TASK_ID = "sql-operator" HOST = "host" @@ -259,6 +263,34 @@ def test_placeholder_multiple_times_and_make_sure_connection_is_only_invoked_onc assert dbapi_hook.placeholder == "%s" assert dbapi_hook.connection_invocations == 1 + @pytest.mark.db_test + def test_dialect_name(self): + dbapi_hook = mock_hook(DbApiHook) + assert dbapi_hook.dialect_name == "default" + + @pytest.mark.db_test + def test_dialect(self): + dbapi_hook = mock_hook(DbApiHook) + assert isinstance(dbapi_hook.dialect, Dialect) + + @pytest.mark.db_test + def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_code(self): + """ + Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher + for this provider, you should remove the obsolete code in the get_dialects method of the DbApiHook + and remove this test. This test was added to make sure to not forget to remove the fallback code + for backward compatibility with Airflow 2.8.x which isn't need anymore once this provider depends on + Airflow 3.0.0 or higher. + """ + min_airflow_version = get_provider_min_airflow_version("apache-airflow-providers-common-sql") + + # Check if the current Airflow version is 3.0.0 or higher + if min_airflow_version[0] >= 3: + method_source = inspect.getsource(resolve_dialects) + raise AirflowProviderDeprecationWarning( + f"Check TODO's to remove obsolete code in resolve_dialects method:\n\r\n\r\t\t\t{method_source}" + ) + @pytest.mark.db_test def test_uri(self): dbapi_hook = mock_hook(DbApiHook) diff --git a/providers/tests/jdbc/hooks/test_jdbc.py b/providers/tests/jdbc/hooks/test_jdbc.py index ce4e526623440..646b3e9c09e10 100644 --- a/providers/tests/jdbc/hooks/test_jdbc.py +++ b/providers/tests/jdbc/hooks/test_jdbc.py @@ -44,17 +44,27 @@ def get_hook( hook_params=None, conn_params=None, + conn_type: str | None = None, login: str | None = "login", password: str | None = "password", host: str | None = "host", schema: str | None = "schema", port: int | None = 1234, + uri: str | None = None, ): hook_params = hook_params or {} conn_params = conn_params or {} connection = Connection( **{ - **dict(login=login, password=password, host=host, schema=schema, port=port), + **dict( + conn_type=conn_type, + login=login, + password=password, + host=host, + schema=schema, + port=port, + uri=uri, + ), **conn_params, } ) @@ -251,6 +261,19 @@ def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): engine = jdbc_hook.get_sqlalchemy_engine() assert engine.connect().connection.connection == connection + def test_dialect_name(self): + jdbc_hook = get_hook( + conn_params=dict(extra={"sqlalchemy_scheme": "hana"}), + conn_type="jdbc", + login=None, + password=None, + host="localhost", + schema="sap", + port=30215, + ) + + assert jdbc_hook.dialect_name == "hana" + def test_get_conn_thread_safety(self): mock_conn = MagicMock() open_connections = 0 diff --git a/providers/tests/microsoft/mssql/dialects/__init__.py b/providers/tests/microsoft/mssql/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/microsoft/mssql/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/tests/microsoft/mssql/dialects/test_mssql.py b/providers/tests/microsoft/mssql/dialects/test_mssql.py new file mode 100644 index 0000000000000..762c4c463dfdb --- /dev/null +++ b/providers/tests/microsoft/mssql/dialects/test_mssql.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +from sqlalchemy.engine import Inspector + +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect + + +class TestMsSqlDialect: + def setup_method(self): + inspector = MagicMock(spc=Inspector) + inspector.get_columns.side_effect = lambda table_name, schema: [ + {"name": "id", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ] + self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook) + self.test_db_hook.run.side_effect = lambda *args: [("id",)] + self.test_db_hook._escape_column_name_format = '"{}"' + + def test_placeholder(self): + assert MsSqlDialect(self.test_db_hook).placeholder == "?" + + def test_get_column_names(self): + assert MsSqlDialect(self.test_db_hook).get_column_names("hollywood.actors") == [ + "id", + "name", + "firstname", + "age", + ] + + def test_get_target_fields(self): + assert MsSqlDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [ + "name", + "firstname", + "age", + ] + + def test_get_primary_keys(self): + assert MsSqlDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["id"] + + def test_generate_replace_sql(self): + values = [ + {"id": "id", "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"id": "id", "name": "Statham", "firstname": "Jason", "age": "57"}, + {"id": "id", "name": "Li", "firstname": "Jet", "age": "61"}, + {"id": "id", "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"id": "id", "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + target_fields = ["id", "name", "firstname", "age"] + sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields) + assert ( + sql + == """ + MERGE INTO hollywood.actors WITH (ROWLOCK) AS target + USING (SELECT ? AS id, ? AS name, ? AS firstname, ? AS age) AS source + ON target.id = source.id + WHEN MATCHED THEN + UPDATE SET target.name = source.name, target.firstname = source.firstname, target.age = source.age + WHEN NOT MATCHED THEN + INSERT (id, name, firstname, age) VALUES (source.id, source.name, source.firstname, source.age); + """.strip() + ) diff --git a/providers/tests/microsoft/mssql/hooks/test_mssql.py b/providers/tests/microsoft/mssql/hooks/test_mssql.py index be8f921112a4a..7153edde85217 100644 --- a/providers/tests/microsoft/mssql/hooks/test_mssql.py +++ b/providers/tests/microsoft/mssql/hooks/test_mssql.py @@ -20,8 +20,11 @@ from unittest import mock import pytest +import sqlalchemy +from airflow.configuration import conf from airflow.models import Connection +from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect from providers.tests.microsoft.conftest import load_file @@ -30,9 +33,72 @@ except ImportError: pytest.skip("MSSQL not available", allow_module_level=True) +PYMSSQL_CONN = Connection( + conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081 +) +PYMSSQL_CONN_ALT = Connection( + conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081 +) +PYMSSQL_CONN_ALT_1 = Connection( + conn_type="mssql", + host="ip", + schema="", + login="username", + password="password", + port=8081, + extra={"SQlalchemy_Scheme": "mssql+testdriver"}, +) +PYMSSQL_CONN_ALT_2 = Connection( + conn_type="mssql", + host="ip", + schema="", + login="username", + password="password", + port=8081, + extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"}, +) + + +def get_target_fields(self, table: str) -> list[str] | None: + return [ + "ReportRefreshDate", + "UserId", + "UserPrincipalName", + "LastActivityDate", + "IsDeleted", + "DeletedDate", + "AssignedProducts", + "TeamChatMessageCount", + "PrivateChatMessageCount", + "CallCount", + "MeetingCount", + "MeetingsOrganizedCount", + "MeetingsAttendedCount", + "AdHocMeetingsOrganizedCount", + "AdHocMeetingsAttendedCount", + "ScheduledOne-timeMeetingsOrganizedCount", + "ScheduledOne-timeMeetingsAttendedCount", + "ScheduledRecurringMeetingsOrganizedCount", + "ScheduledRecurringMeetingsAttendedCount", + "AudioDuration", + "VideoDuration", + "ScreenShareDuration", + "AudioDurationInSeconds", + "VideoDurationInSeconds", + "ScreenShareDurationInSeconds", + "HasOtherAction", + "UrgentMessages", + "PostMessages", + "TenantDisplayName", + "SharedChannelTenantDisplayNames", + "ReplyMessages", + "IsLicensed", + "ReportPeriod", + "LoadDate", + ] -@pytest.fixture -def get_primary_keys(): + +def get_primary_keys(self, table: str) -> list[str] | None: return [ "GroupDisplayName", "OwnerPrincipalName", @@ -80,6 +146,14 @@ def mssql_connections(): class TestMsSqlHook: + def setup_method(self): + MsSqlHook._resolve_target_fields = True + + def teardown_method(self, method): + MsSqlHook._resolve_target_fields = conf.getboolean( + "core", "dbapihook_resolve_target_fields", fallback=False + ) + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection") def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn, mssql_connections): @@ -161,88 +235,71 @@ def test_get_sqlalchemy_engine(self, get_connection, mssql_connections): hook.get_sqlalchemy_engine() @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - def test_generate_insert_sql(self, get_connection, mssql_connections, get_primary_keys): - get_connection.return_value = mssql_connections["default"] + @mock.patch( + "airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect.get_target_fields", + get_target_fields, + ) + @mock.patch( + "airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect.get_primary_keys", + get_primary_keys, + ) + def test_generate_insert_sql(self, get_connection): + get_connection.return_value = PYMSSQL_CONN + + hook = MsSqlHook() + sql = hook._generate_insert_sql( + table="YAMMER_GROUPS_ACTIVITY_DETAIL", + values=[ + "2024-07-17", + "daa5b44c-80d6-4e22-85b5-a94e04cf7206", + "no-reply@microsoft.com", + "2024-07-17", + 0, + 0.0, + "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5", + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + "PT0S", + "PT0S", + "PT0S", + 0, + 0, + 0, + "Yes", + 0, + 0, + "APACHE", + 0.0, + 0, + "Yes", + 1, + "2024-07-17T00:00:00+00:00", + ], + replace=True, + ) + assert sql == load_file("resources", "replace.sql") + + def test_dialect_name(self): + hook = MsSqlHook() + assert hook.dialect_name == "mssql" + + def test_dialect(self): + hook = MsSqlHook() + assert isinstance(hook.dialect, MsSqlDialect) + def test_reserved_words(self): hook = MsSqlHook() - with mock.patch.object(hook, "get_primary_keys", return_value=get_primary_keys): - sql = hook._generate_insert_sql( - table="YAMMER_GROUPS_ACTIVITY_DETAIL", - values=[ - "2024-07-17", - "daa5b44c-80d6-4e22-85b5-a94e04cf7206", - "no-reply@microsoft.com", - "2024-07-17", - 0, - 0.0, - "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5", - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - "PT0S", - "PT0S", - "PT0S", - 0, - 0, - 0, - "Yes", - 0, - 0, - "APACHE", - 0.0, - 0, - "Yes", - 1, - "2024-07-17T00:00:00+00:00", - ], - target_fields=[ - "ReportRefreshDate", - "UserId", - "UserPrincipalName", - "LastActivityDate", - "IsDeleted", - "DeletedDate", - "AssignedProducts", - "TeamChatMessageCount", - "PrivateChatMessageCount", - "CallCount", - "MeetingCount", - "MeetingsOrganizedCount", - "MeetingsAttendedCount", - "AdHocMeetingsOrganizedCount", - "AdHocMeetingsAttendedCount", - "ScheduledOne-timeMeetingsOrganizedCount", - "ScheduledOne-timeMeetingsAttendedCount", - "ScheduledRecurringMeetingsOrganizedCount", - "ScheduledRecurringMeetingsAttendedCount", - "AudioDuration", - "VideoDuration", - "ScreenShareDuration", - "AudioDurationInSeconds", - "VideoDurationInSeconds", - "ScreenShareDurationInSeconds", - "HasOtherAction", - "UrgentMessages", - "PostMessages", - "TenantDisplayName", - "SharedChannelTenantDisplayNames", - "ReplyMessages", - "IsLicensed", - "ReportPeriod", - "LoadDate", - ], - replace=True, - ) - assert sql == load_file("resources", "replace.sql") + assert hook.reserved_words == sqlalchemy.dialects.mssql.base.RESERVED_WORDS @pytest.mark.db_test @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") diff --git a/providers/tests/microsoft/mssql/resources/replace.sql b/providers/tests/microsoft/mssql/resources/replace.sql index f8fb93b382ea6..07c7ec29e0188 100644 --- a/providers/tests/microsoft/mssql/resources/replace.sql +++ b/providers/tests/microsoft/mssql/resources/replace.sql @@ -17,10 +17,10 @@ under the License. */ -MERGE INTO YAMMER_GROUPS_ACTIVITY_DETAIL AS target - USING (SELECT %s AS ReportRefreshDate, %s AS UserId, %s AS UserPrincipalName, %s AS LastActivityDate, %s AS IsDeleted, %s AS DeletedDate, %s AS AssignedProducts, %s AS TeamChatMessageCount, %s AS PrivateChatMessageCount, %s AS CallCount, %s AS MeetingCount, %s AS MeetingsOrganizedCount, %s AS MeetingsAttendedCount, %s AS AdHocMeetingsOrganizedCount, %s AS AdHocMeetingsAttendedCount, %s AS ScheduledOne-timeMeetingsOrganizedCount, %s AS ScheduledOne-timeMeetingsAttendedCount, %s AS ScheduledRecurringMeetingsOrganizedCount, %s AS ScheduledRecurringMeetingsAttendedCount, %s AS AudioDuration, %s AS VideoDuration, %s AS ScreenShareDuration, %s AS AudioDurationInSeconds, %s AS VideoDurationInSeconds, %s AS ScreenShareDurationInSeconds, %s AS HasOtherAction, %s AS UrgentMessages, %s AS PostMessages, %s AS TenantDisplayName, %s AS SharedChannelTenantDisplayNames, %s AS ReplyMessages, %s AS IsLicensed, %s AS ReportPeriod, %s AS LoadDate) AS source - ON target.GroupDisplayName = source.GroupDisplayName AND target.OwnerPrincipalName = source.OwnerPrincipalName AND target.ReportPeriod = source.ReportPeriod AND target.ReportRefreshDate = source.ReportRefreshDate - WHEN MATCHED THEN - UPDATE SET target.UserId = source.UserId, target.UserPrincipalName = source.UserPrincipalName, target.LastActivityDate = source.LastActivityDate, target.IsDeleted = source.IsDeleted, target.DeletedDate = source.DeletedDate, target.AssignedProducts = source.AssignedProducts, target.TeamChatMessageCount = source.TeamChatMessageCount, target.PrivateChatMessageCount = source.PrivateChatMessageCount, target.CallCount = source.CallCount, target.MeetingCount = source.MeetingCount, target.MeetingsOrganizedCount = source.MeetingsOrganizedCount, target.MeetingsAttendedCount = source.MeetingsAttendedCount, target.AdHocMeetingsOrganizedCount = source.AdHocMeetingsOrganizedCount, target.AdHocMeetingsAttendedCount = source.AdHocMeetingsAttendedCount, target.ScheduledOne-timeMeetingsOrganizedCount = source.ScheduledOne-timeMeetingsOrganizedCount, target.ScheduledOne-timeMeetingsAttendedCount = source.ScheduledOne-timeMeetingsAttendedCount, target.ScheduledRecurringMeetingsOrganizedCount = source.ScheduledRecurringMeetingsOrganizedCount, target.ScheduledRecurringMeetingsAttendedCount = source.ScheduledRecurringMeetingsAttendedCount, target.AudioDuration = source.AudioDuration, target.VideoDuration = source.VideoDuration, target.ScreenShareDuration = source.ScreenShareDuration, target.AudioDurationInSeconds = source.AudioDurationInSeconds, target.VideoDurationInSeconds = source.VideoDurationInSeconds, target.ScreenShareDurationInSeconds = source.ScreenShareDurationInSeconds, target.HasOtherAction = source.HasOtherAction, target.UrgentMessages = source.UrgentMessages, target.PostMessages = source.PostMessages, target.TenantDisplayName = source.TenantDisplayName, target.SharedChannelTenantDisplayNames = source.SharedChannelTenantDisplayNames, target.ReplyMessages = source.ReplyMessages, target.IsLicensed = source.IsLicensed, target.LoadDate = source.LoadDate - WHEN NOT MATCHED THEN - INSERT (ReportRefreshDate, UserId, UserPrincipalName, LastActivityDate, IsDeleted, DeletedDate, AssignedProducts, TeamChatMessageCount, PrivateChatMessageCount, CallCount, MeetingCount, MeetingsOrganizedCount, MeetingsAttendedCount, AdHocMeetingsOrganizedCount, AdHocMeetingsAttendedCount, ScheduledOne-timeMeetingsOrganizedCount, ScheduledOne-timeMeetingsAttendedCount, ScheduledRecurringMeetingsOrganizedCount, ScheduledRecurringMeetingsAttendedCount, AudioDuration, VideoDuration, ScreenShareDuration, AudioDurationInSeconds, VideoDurationInSeconds, ScreenShareDurationInSeconds, HasOtherAction, UrgentMessages, PostMessages, TenantDisplayName, SharedChannelTenantDisplayNames, ReplyMessages, IsLicensed, ReportPeriod, LoadDate) VALUES (source.ReportRefreshDate, source.UserId, source.UserPrincipalName, source.LastActivityDate, source.IsDeleted, source.DeletedDate, source.AssignedProducts, source.TeamChatMessageCount, source.PrivateChatMessageCount, source.CallCount, source.MeetingCount, source.MeetingsOrganizedCount, source.MeetingsAttendedCount, source.AdHocMeetingsOrganizedCount, source.AdHocMeetingsAttendedCount, source.ScheduledOne-timeMeetingsOrganizedCount, source.ScheduledOne-timeMeetingsAttendedCount, source.ScheduledRecurringMeetingsOrganizedCount, source.ScheduledRecurringMeetingsAttendedCount, source.AudioDuration, source.VideoDuration, source.ScreenShareDuration, source.AudioDurationInSeconds, source.VideoDurationInSeconds, source.ScreenShareDurationInSeconds, source.HasOtherAction, source.UrgentMessages, source.PostMessages, source.TenantDisplayName, source.SharedChannelTenantDisplayNames, source.ReplyMessages, source.IsLicensed, source.ReportPeriod, source.LoadDate); +MERGE INTO YAMMER_GROUPS_ACTIVITY_DETAIL WITH (ROWLOCK) AS target + USING (SELECT %s AS ReportRefreshDate, %s AS UserId, %s AS UserPrincipalName, %s AS LastActivityDate, %s AS IsDeleted, %s AS DeletedDate, %s AS AssignedProducts, %s AS TeamChatMessageCount, %s AS PrivateChatMessageCount, %s AS CallCount, %s AS MeetingCount, %s AS MeetingsOrganizedCount, %s AS MeetingsAttendedCount, %s AS AdHocMeetingsOrganizedCount, %s AS AdHocMeetingsAttendedCount, %s AS ScheduledOne-timeMeetingsOrganizedCount, %s AS ScheduledOne-timeMeetingsAttendedCount, %s AS ScheduledRecurringMeetingsOrganizedCount, %s AS ScheduledRecurringMeetingsAttendedCount, %s AS AudioDuration, %s AS VideoDuration, %s AS ScreenShareDuration, %s AS AudioDurationInSeconds, %s AS VideoDurationInSeconds, %s AS ScreenShareDurationInSeconds, %s AS HasOtherAction, %s AS UrgentMessages, %s AS PostMessages, %s AS TenantDisplayName, %s AS SharedChannelTenantDisplayNames, %s AS ReplyMessages, %s AS IsLicensed, %s AS ReportPeriod, %s AS LoadDate) AS source + ON target.GroupDisplayName = source.GroupDisplayName AND target.OwnerPrincipalName = source.OwnerPrincipalName AND target.ReportPeriod = source.ReportPeriod AND target.ReportRefreshDate = source.ReportRefreshDate + WHEN MATCHED THEN + UPDATE SET target.UserId = source.UserId, target.UserPrincipalName = source.UserPrincipalName, target.LastActivityDate = source.LastActivityDate, target.IsDeleted = source.IsDeleted, target.DeletedDate = source.DeletedDate, target.AssignedProducts = source.AssignedProducts, target.TeamChatMessageCount = source.TeamChatMessageCount, target.PrivateChatMessageCount = source.PrivateChatMessageCount, target.CallCount = source.CallCount, target.MeetingCount = source.MeetingCount, target.MeetingsOrganizedCount = source.MeetingsOrganizedCount, target.MeetingsAttendedCount = source.MeetingsAttendedCount, target.AdHocMeetingsOrganizedCount = source.AdHocMeetingsOrganizedCount, target.AdHocMeetingsAttendedCount = source.AdHocMeetingsAttendedCount, target.ScheduledOne-timeMeetingsOrganizedCount = source.ScheduledOne-timeMeetingsOrganizedCount, target.ScheduledOne-timeMeetingsAttendedCount = source.ScheduledOne-timeMeetingsAttendedCount, target.ScheduledRecurringMeetingsOrganizedCount = source.ScheduledRecurringMeetingsOrganizedCount, target.ScheduledRecurringMeetingsAttendedCount = source.ScheduledRecurringMeetingsAttendedCount, target.AudioDuration = source.AudioDuration, target.VideoDuration = source.VideoDuration, target.ScreenShareDuration = source.ScreenShareDuration, target.AudioDurationInSeconds = source.AudioDurationInSeconds, target.VideoDurationInSeconds = source.VideoDurationInSeconds, target.ScreenShareDurationInSeconds = source.ScreenShareDurationInSeconds, target.HasOtherAction = source.HasOtherAction, target.UrgentMessages = source.UrgentMessages, target.PostMessages = source.PostMessages, target.TenantDisplayName = source.TenantDisplayName, target.SharedChannelTenantDisplayNames = source.SharedChannelTenantDisplayNames, target.ReplyMessages = source.ReplyMessages, target.IsLicensed = source.IsLicensed, target.LoadDate = source.LoadDate + WHEN NOT MATCHED THEN + INSERT (ReportRefreshDate, UserId, UserPrincipalName, LastActivityDate, IsDeleted, DeletedDate, AssignedProducts, TeamChatMessageCount, PrivateChatMessageCount, CallCount, MeetingCount, MeetingsOrganizedCount, MeetingsAttendedCount, AdHocMeetingsOrganizedCount, AdHocMeetingsAttendedCount, ScheduledOne-timeMeetingsOrganizedCount, ScheduledOne-timeMeetingsAttendedCount, ScheduledRecurringMeetingsOrganizedCount, ScheduledRecurringMeetingsAttendedCount, AudioDuration, VideoDuration, ScreenShareDuration, AudioDurationInSeconds, VideoDurationInSeconds, ScreenShareDurationInSeconds, HasOtherAction, UrgentMessages, PostMessages, TenantDisplayName, SharedChannelTenantDisplayNames, ReplyMessages, IsLicensed, ReportPeriod, LoadDate) VALUES (source.ReportRefreshDate, source.UserId, source.UserPrincipalName, source.LastActivityDate, source.IsDeleted, source.DeletedDate, source.AssignedProducts, source.TeamChatMessageCount, source.PrivateChatMessageCount, source.CallCount, source.MeetingCount, source.MeetingsOrganizedCount, source.MeetingsAttendedCount, source.AdHocMeetingsOrganizedCount, source.AdHocMeetingsAttendedCount, source.ScheduledOne-timeMeetingsOrganizedCount, source.ScheduledOne-timeMeetingsAttendedCount, source.ScheduledRecurringMeetingsOrganizedCount, source.ScheduledRecurringMeetingsAttendedCount, source.AudioDuration, source.VideoDuration, source.ScreenShareDuration, source.AudioDurationInSeconds, source.VideoDurationInSeconds, source.ScreenShareDurationInSeconds, source.HasOtherAction, source.UrgentMessages, source.PostMessages, source.TenantDisplayName, source.SharedChannelTenantDisplayNames, source.ReplyMessages, source.IsLicensed, source.ReportPeriod, source.LoadDate); diff --git a/providers/tests/mysql/hooks/test_mysql.py b/providers/tests/mysql/hooks/test_mysql.py index 0e21047b107b5..6facc6f4b44e6 100644 --- a/providers/tests/mysql/hooks/test_mysql.py +++ b/providers/tests/mysql/hooks/test_mysql.py @@ -24,6 +24,7 @@ from unittest import mock import pytest +import sqlalchemy from airflow.models import Connection from airflow.models.dag import DAG @@ -31,18 +32,20 @@ try: import MySQLdb.cursors - from airflow.providers.mysql.hooks.mysql import MySqlHook + MYSQL_AVAILABLE = True except ImportError: - pytest.skip("MySQL not available", allow_module_level=True) - + MYSQL_AVAILABLE = False +from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.utils import timezone from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key": "/tmp/client-key.pem"} +INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, `schema`, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)" +@pytest.mark.skipif(not MYSQL_AVAILABLE, reason="MySQL not available") class TestMySqlHookConn: def setup_method(self): self.connection = Connection( @@ -223,6 +226,7 @@ def autocommit(self, autocommit): self._autocommit = autocommit +@pytest.mark.db_test class TestMySqlHook: def setup_method(self): self.cur = mock.MagicMock(rowcount=0) @@ -327,6 +331,80 @@ def test_bulk_load_custom(self): ), ) + def test_reserved_words(self): + hook = MySqlHook() + assert hook.reserved_words == sqlalchemy.dialects.mysql.reserved_words.RESERVED_WORDS_MYSQL + + def test_generate_insert_sql_without_already_escaped_column_name(self): + values = [ + "1", + "mssql_conn", + "mssql", + "MSSQL connection", + "localhost", + "airflow", + "admin", + "admin", + 1433, + False, + False, + {}, + ] + target_fields = [ + "id", + "conn_id", + "conn_type", + "description", + "host", + "schema", + "login", + "password", + "port", + "is_encrypted", + "is_extra_encrypted", + "extra", + ] + hook = MySqlHook() + assert ( + hook._generate_insert_sql(table="connection", values=values, target_fields=target_fields) + == INSERT_SQL_STATEMENT + ) + + def test_generate_insert_sql_with_already_escaped_column_name(self): + values = [ + "1", + "mssql_conn", + "mssql", + "MSSQL connection", + "localhost", + "airflow", + "admin", + "admin", + 1433, + False, + False, + {}, + ] + target_fields = [ + "id", + "conn_id", + "conn_type", + "description", + "host", + "`schema`", + "login", + "password", + "port", + "is_encrypted", + "is_extra_encrypted", + "extra", + ] + hook = MySqlHook() + assert ( + hook._generate_insert_sql(table="connection", values=values, target_fields=target_fields) + == INSERT_SQL_STATEMENT + ) + DEFAULT_DATE = timezone.datetime(2015, 1, 1) DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() @@ -348,6 +426,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @pytest.mark.backend("mysql") +@pytest.mark.skipif(not MYSQL_AVAILABLE, reason="MySQL not available") class TestMySql: def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} diff --git a/providers/tests/odbc/hooks/test_odbc.py b/providers/tests/odbc/hooks/test_odbc.py index 5d2e195dcc640..038bb4e1c4ff4 100644 --- a/providers/tests/odbc/hooks/test_odbc.py +++ b/providers/tests/odbc/hooks/test_odbc.py @@ -27,6 +27,7 @@ import pyodbc import pytest +from sqlalchemy.exc import ArgumentError from airflow.providers.odbc.hooks.odbc import OdbcHook @@ -78,6 +79,10 @@ class PyodbcRow(metaclass=PyodbcRowMeta): return PyodbcRow +def raise_argument_error(): + raise ArgumentError() + + class TestOdbcHook: def test_driver_in_extra_not_used(self): conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param"))) @@ -342,6 +347,26 @@ def test_query_no_handler_return_none(self): result = hook.run("SQL") assert result is None + def test_dialect_name_when_resolved_from_sqlalchemy_uri(self): + hook = mock_hook(OdbcHook) + assert hook.dialect_name == "mssql" + + def test_dialect_name_when_resolved_from_conn_type(self): + hook = mock_hook(OdbcHook) + hook.get_conn().conn_type = "sqlite" + hook.get_uri = raise_argument_error + assert hook.dialect_name == "default" + + def test_dialect_name_when_resolved_from_sqlalchemy_scheme_in_extra(self): + hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "mssql+pymssql"}}) + hook.get_uri = raise_argument_error + assert hook.dialect_name == "mssql" + + def test_dialect_name_when_resolved_from_dialect_in_extra(self): + hook = mock_hook(OdbcHook, conn_params={"extra": {"dialect": "oracle"}}) + hook.get_uri = raise_argument_error + assert hook.dialect_name == "oracle" + def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "sqlite"}}) diff --git a/providers/tests/postgres/dialects/__init__.py b/providers/tests/postgres/dialects/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/postgres/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/tests/postgres/dialects/test_postgres.py b/providers/tests/postgres/dialects/test_postgres.py new file mode 100644 index 0000000000000..ab4968a66456b --- /dev/null +++ b/providers/tests/postgres/dialects/test_postgres.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +from sqlalchemy.engine import Inspector + +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.postgres.dialects.postgres import PostgresDialect + + +class TestPostgresDialect: + def setup_method(self): + inspector = MagicMock(spc=Inspector) + inspector.get_columns.side_effect = lambda table_name, schema: [ + {"name": "id", "identity": True}, + {"name": "name"}, + {"name": "firstname"}, + {"name": "age"}, + ] + + def get_records(sql, parameters): + assert isinstance(sql, str) + assert "hollywood" in parameters, "Missing 'schema' in parameters" + assert "actors" in parameters, "Missing 'table' in parameters" + return [("id",)] + + self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook) + self.test_db_hook.get_records.side_effect = get_records + self.test_db_hook._insert_statement_format = "INSERT INTO {} {} VALUES ({})" + self.test_db_hook._escape_column_name_format = '"{}"' + + def test_placeholder(self): + assert PostgresDialect(self.test_db_hook).placeholder == "?" + + def test_get_column_names(self): + assert PostgresDialect(self.test_db_hook).get_column_names("hollywood.actors") == [ + "id", + "name", + "firstname", + "age", + ] + + def test_get_target_fields(self): + assert PostgresDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [ + "name", + "firstname", + "age", + ] + + def test_get_primary_keys(self): + assert PostgresDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["id"] + + def test_generate_replace_sql(self): + values = [ + {"id": "id", "name": "Stallone", "firstname": "Sylvester", "age": "78"}, + {"id": "id", "name": "Statham", "firstname": "Jason", "age": "57"}, + {"id": "id", "name": "Li", "firstname": "Jet", "age": "61"}, + {"id": "id", "name": "Lundgren", "firstname": "Dolph", "age": "66"}, + {"id": "id", "name": "Norris", "firstname": "Chuck", "age": "84"}, + ] + target_fields = ["id", "name", "firstname", "age"] + sql = PostgresDialect(self.test_db_hook).generate_replace_sql( + "hollywood.actors", values, target_fields + ) + assert ( + sql + == """ + INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) ON CONFLICT (id) DO UPDATE SET name = excluded.name, firstname = excluded.firstname, age = excluded.age + """.strip() + ) diff --git a/providers/tests/postgres/hooks/test_postgres.py b/providers/tests/postgres/hooks/test_postgres.py index 76206d5795866..2483dba913293 100644 --- a/providers/tests/postgres/hooks/test_postgres.py +++ b/providers/tests/postgres/hooks/test_postgres.py @@ -24,12 +24,16 @@ import psycopg2.extras import pytest +import sqlalchemy from airflow.exceptions import AirflowException from airflow.models import Connection +from airflow.providers.postgres.dialects.postgres import PostgresDialect from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.types import NOTSET +INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, {}, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)" + class TestPostgresHookConn: def setup_method(self): @@ -645,3 +649,81 @@ def test_log_db_messages_by_db_proc(self, caplog): assert "NOTICE: Message from db: 42" in caplog.text finally: hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)") + + def test_dialect_name(self): + assert self.db_hook.dialect_name == "postgresql" + + def test_dialect(self): + assert isinstance(self.db_hook.dialect, PostgresDialect) + + def test_reserved_words(self): + hook = PostgresHook() + assert hook.reserved_words == sqlalchemy.dialects.postgresql.base.RESERVED_WORDS + + def test_generate_insert_sql_without_already_escaped_column_name(self): + values = [ + "1", + "mssql_conn", + "mssql", + "MSSQL connection", + "localhost", + "airflow", + "admin", + "admin", + 1433, + False, + False, + {}, + ] + target_fields = [ + "id", + "conn_id", + "conn_type", + "description", + "host", + "schema", + "login", + "password", + "port", + "is_encrypted", + "is_extra_encrypted", + "extra", + ] + hook = PostgresHook() + assert hook._generate_insert_sql( + table="connection", values=values, target_fields=target_fields + ) == INSERT_SQL_STATEMENT.format("schema") + + def test_generate_insert_sql_with_already_escaped_column_name(self): + values = [ + "1", + "mssql_conn", + "mssql", + "MSSQL connection", + "localhost", + "airflow", + "admin", + "admin", + 1433, + False, + False, + {}, + ] + target_fields = [ + "id", + "conn_id", + "conn_type", + "description", + "host", + '"schema"', + "login", + "password", + "port", + "is_encrypted", + "is_extra_encrypted", + "extra", + ] + hook = PostgresHook() + assert hook._generate_insert_sql( + table="connection", values=values, target_fields=target_fields + ) == INSERT_SQL_STATEMENT.format('"schema"') diff --git a/providers/tests/teradata/hooks/test_teradata.py b/providers/tests/teradata/hooks/test_teradata.py index f2bcb1ee269ba..10754555b1a4d 100644 --- a/providers/tests/teradata/hooks/test_teradata.py +++ b/providers/tests/teradata/hooks/test_teradata.py @@ -28,6 +28,7 @@ class TestTeradataHook: def setup_method(self): self.connection = Connection( + conn_id="teradata_conn_id", conn_type="teradata", login="login", password="password", @@ -43,12 +44,14 @@ def setup_method(self): conn = self.conn class UnitTestTeradataHook(TeradataHook): - conn_name_attr = "teradata_conn_id" - def get_conn(self): return conn - self.test_db_hook = UnitTestTeradataHook() + @classmethod + def get_connection(cls, conn_id: str) -> Connection: + return conn + + self.test_db_hook = UnitTestTeradataHook(teradata_conn_id="teradata_conn_id") @mock.patch("teradatasql.connect") def test_get_conn(self, mock_connect): diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py index 4af0c72971363..a808aedbb80a7 100644 --- a/tests/always/test_providers_manager.py +++ b/tests/always/test_providers_manager.py @@ -446,7 +446,8 @@ def test_auth_managers(self): def test_dialects(self): provider_manager = ProvidersManager() dialect_class_names = list(provider_manager.dialects) - assert len(dialect_class_names) == 0 + assert len(dialect_class_names) == 3 + assert dialect_class_names == ["default", "mssql", "postgresql"] @patch("airflow.providers_manager.import_string") def test_optional_feature_no_warning(self, mock_importlib_import_string):