[DOP-22425] Run DB.check() on both driver and executor
dolfinus committed Feb 21, 2025
1 parent 7437eb3 commit 30c332e
Showing 10 changed files with 88 additions and 28 deletions.
2 changes: 2 additions & 0 deletions docs/changelog/next_release/346.feature.rst
@@ -0,0 +1,2 @@
Now ``DB.check()`` tests connection availability not only on the Spark driver, but also from a Spark executor.
This allows failing immediately if the Spark driver host has network access to the target DB but the Spark executors do not.
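To illustrate how the new behaviour surfaces, here is a minimal sketch of calling ``check()`` on a JDBC connection. The host and credentials are hypothetical, and the Spark session is assumed to already carry the required JDBC driver package:

    from pyspark.sql import SparkSession

    from onetl.connection import Postgres

    spark = SparkSession.builder.appName("onetl-check-demo").getOrCreate()

    postgres = Postgres(
        host="postgres.example.com",  # hypothetical host and credentials
        user="onetl",
        password="secret",
        database="test",
        spark=spark,
    )

    # check() still runs the lightweight "SELECT 1" through the driver-side JDBC
    # connection, and after this change also submits the same query via spark.read,
    # so at least one executor must be able to reach the database as well.
    # Raises RuntimeError("Connection is unavailable") if either step fails.
    postgres.check()

Running the probe on an executor catches the common case where only the driver host is whitelisted on the database side, which previously surfaced only later, during the actual read or write.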
28 changes: 28 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -25,6 +25,7 @@
get_client_info,
get_executor_total_cores,
get_spark_version,
override_job_description,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
@@ -185,6 +186,9 @@ class Greenplum(JDBCMixin, DBConnection): # noqa: WPS338
CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100

_CHECK_QUERY: ClassVar[str] = "SELECT 1"
# any small table which is always present in the DB and which any user can access
# https://www.greenplumdba.com/pg-catalog-tables-and-views
_CHECK_DUMMY_TABLE: ClassVar[str] = "pg_catalog.gp_id"

@slot
@classmethod
@@ -303,6 +307,30 @@ def jdbc_params(self) -> dict:
result.update(self.jdbc_custom_params)
return result

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))

read_options = self._get_connector_params(self._CHECK_DUMMY_TABLE)
read_options["num_partitions"] = 1 # do not require gp_segment_id column in table
df = self.spark.read.format("greenplum").options(**read_options).load()
df.take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def read_source_as_df(
self,
4 changes: 2 additions & 2 deletions onetl/connection/db_connection/hive/connection.py
@@ -180,7 +180,7 @@ def check(self):

try:
with override_job_description(self.spark, f"{self}.check()"):
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
self._execute_sql(self._CHECK_QUERY).take(1)
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -379,7 +379,7 @@ def get_min_max_values(
hint=hint,
)

log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log.info("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, query)

df = self._execute_sql(query)
19 changes: 19 additions & 0 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -89,6 +89,25 @@ def _check_java_class_imported(cls, spark):
raise ValueError(msg) from e
return spark

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
self._query_on_executor(self._CHECK_QUERY, self.SQLOptions(fetchsize=1)).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def sql(
self,
18 changes: 0 additions & 18 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -77,7 +77,6 @@ class JDBCMixin:
ExecuteOptions = JDBCExecuteOptions

DRIVER: ClassVar[str]
_CHECK_QUERY: ClassVar[str] = "SELECT 1"

@property
@abstractmethod
@@ -141,23 +140,6 @@ def __enter__(self):
def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: U101
self.close()

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def fetch(
self,
16 changes: 14 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
@@ -16,7 +16,12 @@

from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_client_info, get_spark_version, stringify
from onetl._util.spark import (
get_client_info,
get_spark_version,
override_job_description,
stringify,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.kafka.dialect import KafkaDialect
@@ -223,7 +228,14 @@ def check(self):
self._log_parameters()

try:
self._get_topics()
with override_job_description(self.spark, f"{self}.check()"):
self._get_topics()

read_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()}
# Reading from any topic the user is allowed to access is enough to check that Kafka is alive
read_options["subscribePattern"] = ".*"
self.spark.read.format("kafka").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
15 changes: 14 additions & 1 deletion onetl/connection/db_connection/mongodb/connection.py
@@ -5,7 +5,7 @@
import json
import logging
import warnings
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar
from urllib import parse as parser

from etl_entities.instance import Host
@@ -129,6 +129,10 @@ class MongoDB(DBConnection):
PipelineOptions = MongoDBPipelineOptions
Extra = MongoDBExtra

# any small collection which is always present in the DB and which any user can access
# https://www.mongodb.com/docs/manual/reference/system-collections/
_CHECK_DUMMY_COLLECTION: ClassVar[str] = "admin.system.version"

@slot
@classmethod
def get_packages(
@@ -373,6 +377,15 @@ def check(self):
jvm = self.spark._jvm # type: ignore
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url)
list(client.listDatabaseNames().iterator())

with override_job_description(self.spark, f"{self}.check()"):
read_options = {
"connection.uri": self.connection_url,
"database": self.database,
"collection": self._CHECK_DUMMY_COLLECTION,
}
self.spark.read.format("mongodb").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -38,7 +38,8 @@ def test_spark_metrics_recorder_postgres_read(spark, processing, load_table_data

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert metrics.input.read_rows == rows
# +1 is just postgres.check()
assert metrics.input.read_rows == rows + 1
# JDBC does not provide information about data size
assert not metrics.input.read_bytes

@@ -64,7 +65,8 @@ def test_spark_metrics_recorder_postgres_read_empty_source(spark, processing, pr

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, processing, load_table_data):
@@ -89,7 +91,8 @@ def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, proces

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_sql(spark, processing, load_table_data):
@@ -103,7 +103,7 @@ def test_oracle_connection_sql(spark, processing, load_table_data, suffix):
filtered_df = table_df[table_df.ID_INT < 50]
processing.assert_equal_df(df=df, other_frame=filtered_df, order_by="id_int")

# client info is expectedc
# client info is expected
df = oracle.sql(f"SELECT program FROM v$session WHERE program LIKE '%onETL%'{suffix}")
client_info = df.collect()[0][0]
assert client_info.startswith("local-")
@@ -142,7 +142,7 @@ def test_oracle_connection_fetch(spark, processing, load_table_data, suffix):
filtered_df = table_df[table_df.ID_INT < 50]
processing.assert_equal_df(df=df, other_frame=filtered_df, order_by="id_int")

# client info is expectedc
# client info is expected
df = oracle.fetch(f"SELECT program FROM v$session WHERE program LIKE '%onETL%'{suffix}")
client_info = df.collect()[0][0]
assert client_info.startswith("local-")
@@ -16,6 +16,7 @@ def test_teradata_connection_check(spark, mocker, caplog):
password = secrets.token_hex()

mocker.patch.object(Teradata, "_query_optional_on_driver")
mocker.patch.object(Teradata, "_query_on_executor")

teradata = Teradata(
host=host,
