diff --git a/docs/changelog/next_release/346.feature.rst b/docs/changelog/next_release/346.feature.rst
new file mode 100644
index 00000000..3d1e4603
--- /dev/null
+++ b/docs/changelog/next_release/346.feature.rst
@@ -0,0 +1,2 @@
+Now ``DB.check()`` tests connection availability not only on the Spark driver, but also from a Spark executor.
+This allows failing immediately if the Spark driver host has network access to the target DB, but the Spark executors do not.
diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py
index 407e0527..005fe2e7 100644
--- a/onetl/connection/db_connection/greenplum/connection.py
+++ b/onetl/connection/db_connection/greenplum/connection.py
@@ -25,6 +25,7 @@
     get_client_info,
     get_executor_total_cores,
     get_spark_version,
+    override_job_description,
 )
 from onetl._util.version import Version
 from onetl.connection.db_connection.db_connection import DBConnection
@@ -185,6 +186,9 @@ class Greenplum(JDBCMixin, DBConnection):  # noqa: WPS338
     CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100

     _CHECK_QUERY: ClassVar[str] = "SELECT 1"
+    # a small table which is always present in the db and which any user can access
+    # https://www.greenplumdba.com/pg-catalog-tables-and-views
+    _CHECK_DUMMY_TABLE: ClassVar[str] = "pg_catalog.gp_id"

     @slot
     @classmethod
@@ -303,6 +307,30 @@ def jdbc_params(self) -> dict:
         result.update(self.jdbc_custom_params)
         return result

+    @slot
+    def check(self):
+        log.info("|%s| Checking connection availability...", self.__class__.__name__)
+        self._log_parameters()  # type: ignore
+
+        log.debug("|%s| Executing SQL query:", self.__class__.__name__)
+        log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)
+
+        try:
+            with override_job_description(self.spark, f"{self}.check()"):
+                self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
+
+                read_options = self._get_connector_params(self._CHECK_DUMMY_TABLE)
+                read_options["num_partitions"] = 1  # do not require gp_segment_id column in table
+                df = self.spark.read.format("greenplum").options(**read_options).load()
+                df.take(1)
+
+            log.info("|%s| Connection is available.", self.__class__.__name__)
+        except Exception as e:
+            log.exception("|%s| Connection is unavailable", self.__class__.__name__)
+            raise RuntimeError("Connection is unavailable") from e
+
+        return self
+
     @slot
     def read_source_as_df(
         self,
diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py
index 4fe1270d..124b3eaf 100644
--- a/onetl/connection/db_connection/hive/connection.py
+++ b/onetl/connection/db_connection/hive/connection.py
@@ -180,7 +180,7 @@ def check(self):

         try:
             with override_job_description(self.spark, f"{self}.check()"):
-                self._execute_sql(self._CHECK_QUERY).limit(1).collect()
+                self._execute_sql(self._CHECK_QUERY).take(1)
             log.info("|%s| Connection is available.", self.__class__.__name__)
         except Exception as e:
             log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -379,7 +379,7 @@ def get_min_max_values(
             hint=hint,
         )

-        log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
+        log.info("|%s| Executing SQL query:", self.__class__.__name__)
         log_lines(log, query)

         df = self._execute_sql(query)
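For context, the user-facing behaviour described in the changelog entry above looks roughly like this (a usage sketch, not part of the diff; the host and credentials are placeholders, the ``Greenplum`` constructor arguments are the usual onETL ones, and ``spark`` is assumed to be an existing ``SparkSession`` with the Greenplum connector attached):

from onetl.connection import Greenplum

greenplum = Greenplum(
    host="greenplum.demo.internal",  # placeholder
    user="reader",                   # placeholder
    password="***",                  # placeholder
    database="dwh",                  # placeholder
    spark=spark,                     # existing SparkSession
)

# Before this change only the driver-side "SELECT 1" was executed.
# Now check() also reads one row from pg_catalog.gp_id through a Spark executor,
# so it raises RuntimeError immediately if executors cannot reach the database.
greenplum.check()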
diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py
index cb1e16a5..bb08f2b4 100644
--- a/onetl/connection/db_connection/jdbc_connection/connection.py
+++ b/onetl/connection/db_connection/jdbc_connection/connection.py
@@ -89,6 +89,25 @@ def _check_java_class_imported(cls, spark):
             raise ValueError(msg) from e
         return spark

+    @slot
+    def check(self):
+        log.info("|%s| Checking connection availability...", self.__class__.__name__)
+        self._log_parameters()  # type: ignore
+
+        log.debug("|%s| Executing SQL query:", self.__class__.__name__)
+        log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)
+
+        try:
+            with override_job_description(self.spark, f"{self}.check()"):
+                self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
+                self._query_on_executor(self._CHECK_QUERY, self.SQLOptions(fetchsize=1)).collect()
+            log.info("|%s| Connection is available.", self.__class__.__name__)
+        except Exception as e:
+            log.exception("|%s| Connection is unavailable", self.__class__.__name__)
+            raise RuntimeError("Connection is unavailable") from e
+
+        return self
+
     @slot
     def sql(
         self,
diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py
index d2bea9b1..22c6ec70 100644
--- a/onetl/connection/db_connection/jdbc_mixin/connection.py
+++ b/onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -77,7 +77,6 @@ class JDBCMixin:
     ExecuteOptions = JDBCExecuteOptions

     DRIVER: ClassVar[str]
-    _CHECK_QUERY: ClassVar[str] = "SELECT 1"

     @property
     @abstractmethod
@@ -141,23 +140,6 @@ def __enter__(self):
     def __exit__(self, _exc_type, _exc_value, _traceback):  # noqa: U101
         self.close()

-    @slot
-    def check(self):
-        log.info("|%s| Checking connection availability...", self.__class__.__name__)
-        self._log_parameters()  # type: ignore
-
-        log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__)
-        log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)
-
-        try:
-            self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
-            log.info("|%s| Connection is available.", self.__class__.__name__)
-        except Exception as e:
-            log.exception("|%s| Connection is unavailable", self.__class__.__name__)
-            raise RuntimeError("Connection is unavailable") from e
-
-        return self
-
     @slot
     def fetch(
         self,
diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py
index 9cb335cc..28fff3e0 100644
--- a/onetl/connection/db_connection/kafka/connection.py
+++ b/onetl/connection/db_connection/kafka/connection.py
@@ -16,7 +16,12 @@

 from onetl._util.java import try_import_java_class
 from onetl._util.scala import get_default_scala_version
-from onetl._util.spark import get_client_info, get_spark_version, stringify
+from onetl._util.spark import (
+    get_client_info,
+    get_spark_version,
+    override_job_description,
+    stringify,
+)
 from onetl._util.version import Version
 from onetl.connection.db_connection.db_connection import DBConnection
 from onetl.connection.db_connection.kafka.dialect import KafkaDialect
@@ -223,7 +228,14 @@ def check(self):
         self._log_parameters()

         try:
-            self._get_topics()
+            with override_job_description(self.spark, f"{self}.check()"):
+                self._get_topics()
+
+                read_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()}
+                # read from any topic the user is allowed to access, just to check that Kafka is alive
+                read_options["subscribePattern"] = ".*"
+                self.spark.read.format("kafka").options(**read_options).load().take(1)
+
             log.info("|%s| Connection is available.", self.__class__.__name__)
         except Exception as e:
             log.exception("|%s| Connection is unavailable", self.__class__.__name__)
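The Kafka probe added above boils down to a plain Spark batch read, roughly the following (a sketch; ``spark`` is an existing ``SparkSession``, and the bootstrap servers plus any security options are placeholders for whatever ``_get_connection_properties()`` returns):

read_options = {
    "kafka.bootstrap.servers": "broker-1:9092",  # placeholder; real values come from the connection
    "subscribePattern": ".*",  # any topic the user is allowed to read is enough for the probe
}
# .take(1) triggers an actual Spark job, so the brokers must be reachable from an executor,
# not only from the driver where _get_topics() runs
spark.read.format("kafka").options(**read_options).load().take(1)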
diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py
index 4f455cec..c2164d42 100644
--- a/onetl/connection/db_connection/mongodb/connection.py
+++ b/onetl/connection/db_connection/mongodb/connection.py
@@ -5,7 +5,7 @@
 import json
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar
 from urllib import parse as parser

 from etl_entities.instance import Host
@@ -129,6 +129,10 @@ class MongoDB(DBConnection):
     PipelineOptions = MongoDBPipelineOptions
     Extra = MongoDBExtra

+    # a small collection which is always present in the db and which any user can access
+    # https://www.mongodb.com/docs/manual/reference/system-collections/
+    _CHECK_DUMMY_COLLECTION: ClassVar[str] = "admin.system.version"
+
     @slot
     @classmethod
     def get_packages(
@@ -373,6 +377,15 @@ def check(self):
             jvm = self.spark._jvm  # type: ignore
             client = jvm.com.mongodb.client.MongoClients.create(self.connection_url)
             list(client.listDatabaseNames().iterator())
+
+            with override_job_description(self.spark, f"{self}.check()"):
+                read_options = {
+                    "connection.uri": self.connection_url,
+                    "database": self.database,
+                    "collection": self._CHECK_DUMMY_COLLECTION,
+                }
+                self.spark.read.format("mongodb").options(**read_options).load().take(1)
+
             log.info("|%s| Connection is available.", self.__class__.__name__)
         except Exception as e:
             log.exception("|%s| Connection is unavailable", self.__class__.__name__)
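The MongoDB probe complements the driver-side ``listDatabaseNames()`` call with a one-document read executed by Spark, roughly as follows (a sketch; ``spark`` is an existing ``SparkSession``, and the URI and database name are placeholders for ``connection_url`` and ``database``):

read_options = {
    "connection.uri": "mongodb://user:password@mongo.demo.internal:27017",  # placeholder
    "database": "mydb",  # placeholder for the connection's own database
    "collection": "admin.system.version",  # the dummy system collection from _CHECK_DUMMY_COLLECTION
}
spark.read.format("mongodb").options(**read_options).load().take(1)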
diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py
index aa391eb8..58fae7c4 100644
--- a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py
+++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py
@@ -38,7 +38,8 @@ def test_spark_metrics_recorder_postgres_read(spark, processing, load_table_data
         time.sleep(0.1)  # sleep to fetch late metrics from SparkListener

         metrics = recorder.metrics()
-        assert metrics.input.read_rows == rows
+        # +1 is just postgres.check()
+        assert metrics.input.read_rows == rows + 1

         # JDBC does not provide information about data size
         assert not metrics.input.read_bytes
@@ -64,7 +65,8 @@ def test_spark_metrics_recorder_postgres_read_empty_source(spark, processing, pr
         time.sleep(0.1)  # sleep to fetch late metrics from SparkListener

         metrics = recorder.metrics()
-        assert not metrics.input.read_rows
+        # 1 is just postgres.check()
+        assert metrics.input.read_rows == 1


 def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, processing, load_table_data):
@@ -89,7 +91,8 @@ def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, proces
         time.sleep(0.1)  # sleep to fetch late metrics from SparkListener

         metrics = recorder.metrics()
-        assert not metrics.input.read_rows
+        # 1 is just postgres.check()
+        assert metrics.input.read_rows == 1


 def test_spark_metrics_recorder_postgres_sql(spark, processing, load_table_data):
diff --git a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py
index cee16267..0927d2a5 100644
--- a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py
+++ b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py
@@ -103,7 +103,7 @@ def test_oracle_connection_sql(spark, processing, load_table_data, suffix):
     filtered_df = table_df[table_df.ID_INT < 50]
     processing.assert_equal_df(df=df, other_frame=filtered_df, order_by="id_int")

-    # client info is expectedc
+    # client info is expected
     df = oracle.sql(f"SELECT program FROM v$session WHERE program LIKE '%onETL%'{suffix}")
     client_info = df.collect()[0][0]
     assert client_info.startswith("local-")
@@ -142,7 +142,7 @@ def test_oracle_connection_fetch(spark, processing, load_table_data, suffix):
     filtered_df = table_df[table_df.ID_INT < 50]
     processing.assert_equal_df(df=df, other_frame=filtered_df, order_by="id_int")

-    # client info is expectedc
+    # client info is expected
     df = oracle.fetch(f"SELECT program FROM v$session WHERE program LIKE '%onETL%'{suffix}")
     client_info = df.collect()[0][0]
     assert client_info.startswith("local-")
diff --git a/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py b/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py
index f3224025..d905d634 100644
--- a/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py
+++ b/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py
@@ -16,6 +16,7 @@ def test_teradata_connection_check(spark, mocker, caplog):
     password = secrets.token_hex()

     mocker.patch.object(Teradata, "_query_optional_on_driver")
+    mocker.patch.object(Teradata, "_query_on_executor")

     teradata = Teradata(
         host=host,