[DOP-22425] Run DB.check() on both driver and executor
dolfinus committed Feb 19, 2025
1 parent 10e221c commit 4d06b98
Showing 13 changed files with 87 additions and 30 deletions.
2 changes: 2 additions & 0 deletions docs/changelog/next_release/346.feature.rst
@@ -0,0 +1,2 @@
Now ``DB.check()`` tests connection availability not only on the Spark driver, but also from a Spark executor.
This allows failing immediately if the Spark driver host has network access to the target DB, but the Spark executors don't.
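
As an illustration only (not part of this commit), here is a minimal sketch of how the new behavior surfaces to users, assuming a hypothetical Postgres connection; the host, credentials, and Spark session setup below are placeholders:

from pyspark.sql import SparkSession

from onetl.connection import Postgres

# assumes the Postgres JDBC package is already on the Spark classpath
spark = SparkSession.builder.appName("onetl-check-example").getOrCreate()

postgres = Postgres(
    host="postgres.example.com",  # hypothetical host
    user="onetl",
    password="***",  # placeholder
    database="target_db",
    spark=spark,
)

# check() still runs the check query on the driver via JDBC, and now also
# submits a tiny Spark job so that an executor opens its own connection.
# If executors cannot reach the database, a RuntimeError is raised here
# instead of later, in the middle of a read or write.
postgres.check()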
27 changes: 27 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -25,6 +25,7 @@
get_client_info,
get_executor_total_cores,
get_spark_version,
override_job_description,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
@@ -185,6 +186,8 @@ class Greenplum(JDBCMixin, DBConnection):  # noqa: WPS338
CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100

_CHECK_QUERY: ClassVar[str] = "SELECT 1"
# small table, always present on cluster
_CHECK_DUMMY_TABLE: ClassVar[str] = "pg_catalog.gp_id"

@slot
@classmethod
@@ -303,6 +306,30 @@ def jdbc_params(self) -> dict:
result.update(self.jdbc_custom_params)
return result

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))

read_options = self._get_connector_params(self._CHECK_DUMMY_TABLE)
read_options["num_partitions"] = 1 # do not require gp_segment_id column in table
df = self.spark.read.format("greenplum").options(**read_options).load()
df.take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def read_source_as_df(
self,
4 changes: 2 additions & 2 deletions onetl/connection/db_connection/hive/connection.py
@@ -180,7 +180,7 @@ def check(self):

try:
with override_job_description(self.spark, f"{self}.check()"):
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
self._execute_sql(self._CHECK_QUERY).take(1)
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -379,7 +379,7 @@ def get_min_max_values(
hint=hint,
)

log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log.info("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, query)

df = self._execute_sql(query)
19 changes: 19 additions & 0 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -89,6 +89,25 @@ def _check_java_class_imported(cls, spark):
raise ValueError(msg) from e
return spark

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
self._query_on_executor(self._CHECK_QUERY, self.SQLOptions(fetchsize=1)).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def sql(
self,
18 changes: 0 additions & 18 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -77,7 +77,6 @@ class JDBCMixin:
ExecuteOptions = JDBCExecuteOptions

DRIVER: ClassVar[str]
_CHECK_QUERY: ClassVar[str] = "SELECT 1"

@property
@abstractmethod
@@ -141,23 +140,6 @@ def __enter__(self):
def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: U101
self.close()

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def fetch(
self,
16 changes: 14 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
@@ -16,7 +16,12 @@

from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_client_info, get_spark_version, stringify
from onetl._util.spark import (
get_client_info,
get_spark_version,
override_job_description,
stringify,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.kafka.dialect import KafkaDialect
@@ -223,7 +228,14 @@ def check(self):
self._log_parameters()

try:
self._get_topics()
with override_job_description(self.spark, f"{self}.check()"):
self._get_topics()

read_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()}
# Reading any topic the client is allowed to access is enough to check that Kafka is alive
read_options["subscribePattern"] = ".*"
self.spark.read.format("kafka").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
13 changes: 12 additions & 1 deletion onetl/connection/db_connection/mongodb/connection.py
@@ -5,7 +5,7 @@
import json
import logging
import warnings
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar
from urllib import parse as parser

from etl_entities.instance import Host
@@ -129,6 +129,9 @@ class MongoDB(DBConnection):
PipelineOptions = MongoDBPipelineOptions
Extra = MongoDBExtra

# small collection, always present on cluster
_CHECK_DUMMY_COLLECTION: ClassVar[str] = "admin.system.version"

@slot
@classmethod
def get_packages(
@@ -372,6 +375,14 @@ def check(self):
jvm = self.spark._jvm # type: ignore
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url)
list(client.listDatabaseNames().iterator())

with override_job_description(self.spark, f"{self}.check()"):
read_options = {
"connection.uri": self.connection_url,
"collection": self._CHECK_DUMMY_COLLECTION,
}
self.spark.read.format("mongodb").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -38,7 +38,8 @@ def test_spark_metrics_recorder_postgres_read(spark, processing, load_table_data

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert metrics.input.read_rows == rows
# +1 is just postgres.check()
assert metrics.input.read_rows == rows + 1
# JDBC does not provide information about data size
assert not metrics.input.read_bytes

@@ -64,7 +65,8 @@ def test_spark_metrics_recorder_postgres_read_empty_source(spark, processing, pr

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, processing, load_table_data):
@@ -89,7 +91,8 @@ def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, proces

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_sql(spark, processing, load_table_data):
@@ -11,7 +11,7 @@
from onetl.db import DBReader
from tests.util.rand import rand_str

pytestmark = [pytest.mark.oracle, pytest.mark.flaky]
pytestmark = [pytest.mark.oracle, pytest.mark.flaky(reruns=3)]


def test_oracle_reader_snapshot(spark, processing, load_table_data):
@@ -3,7 +3,7 @@
from onetl.connection import Oracle
from onetl.db import DBWriter

pytestmark = [pytest.mark.oracle, pytest.mark.flaky]
pytestmark = [pytest.mark.oracle, pytest.mark.flaky(reruns=3)]


def test_oracle_writer(spark, processing, prepare_schema_table):
@@ -11,7 +11,7 @@
from onetl import __version__ as onetl_version
from onetl.connection import Oracle

pytestmark = [pytest.mark.oracle, pytest.mark.flaky]
pytestmark = [pytest.mark.oracle, pytest.mark.flaky(reruns=3)]


def test_oracle_connection_check(spark, processing, caplog):
@@ -16,6 +16,7 @@ def test_teradata_connection_check(spark, mocker, caplog):
password = secrets.token_hex()

mocker.patch.object(Teradata, "_query_optional_on_driver")
mocker.patch.object(Teradata, "_query_on_executor")

teradata = Teradata(
host=host,
@@ -9,7 +9,7 @@
from onetl.db import DBReader
from onetl.strategy import IncrementalStrategy

pytestmark = [pytest.mark.oracle, pytest.mark.flaky]
pytestmark = [pytest.mark.oracle, pytest.mark.flaky(reruns=3)]


# There is no INTEGER column in Oracle, only NUMERIC