[DOP-22425] Run DB.check() on both driver and executor
dolfinus committed Feb 19, 2025
1 parent 10e221c commit 18a298b
Showing 11 changed files with 95 additions and 38 deletions.
2 changes: 2 additions & 0 deletions docs/changelog/next_release/346.feature.rst
@@ -0,0 +1,2 @@
Now ``DB.check()`` tests connection availability not only on the Spark driver, but also from a Spark executor.
This allows failing immediately if the Spark driver host has network access to the target DB, but the Spark executors do not.
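
A minimal usage sketch of the new behavior (the connection class, host and credentials below are illustrative placeholders, not part of this commit):

# Hypothetical example; any JDBC-based connection behaves the same way.
from onetl.connection import Postgres

postgres = Postgres(
    host="postgres.internal",  # assumed hostname
    user="reader",
    password="***",
    database="app",
    spark=spark,  # an existing SparkSession
)

# Runs the check query on the driver first, then submits the same query as a
# Spark job, so at least one executor must also be able to reach the database.
# Raises RuntimeError("Connection is unavailable") if either side fails.
postgres.check()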
27 changes: 27 additions & 0 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -25,6 +25,7 @@
get_client_info,
get_executor_total_cores,
get_spark_version,
override_job_description,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
@@ -185,6 +186,8 @@ class Greenplum(JDBCMixin, DBConnection):  # noqa: WPS338
CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100

_CHECK_QUERY: ClassVar[str] = "SELECT 1"
# small table, always present on cluster
_CHECK_DUMMY_TABLE: ClassVar[str] = "pg_catalog.gp_id"

@slot
@classmethod
@@ -303,6 +306,30 @@ def jdbc_params(self) -> dict:
result.update(self.jdbc_custom_params)
return result

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))

read_options = self._get_connector_params(self._CHECK_DUMMY_TABLE)
read_options["num_partitions"] = 1 # do not require gp_segment_id column in table
df = self.spark.read.format("greenplum").options(**read_options).load()
df.take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def read_source_as_df(
self,
4 changes: 2 additions & 2 deletions onetl/connection/db_connection/hive/connection.py
@@ -180,7 +180,7 @@ def check(self):

try:
with override_job_description(self.spark, f"{self}.check()"):
self._execute_sql(self._CHECK_QUERY).limit(1).collect()
self._execute_sql(self._CHECK_QUERY).take(1)
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
@@ -379,7 +379,7 @@ def get_min_max_values(
hint=hint,
)

log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log.info("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, query)

df = self._execute_sql(query)
19 changes: 19 additions & 0 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -89,6 +89,25 @@ def _check_java_class_imported(cls, spark):
raise ValueError(msg) from e
return spark

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query:", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
with override_job_description(self.spark, f"{self}.check()"):
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
self._query_on_executor(self._CHECK_QUERY, self.SQLOptions(fetchsize=1)).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def sql(
self,
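
For JDBC connections, the executor-side half of the new check() is conceptually a Spark JDBC read of the check query; a rough stand-alone sketch of what that amounts to (the URL and subquery below are assumptions for illustration, not onetl internals):

# Hypothetical equivalent of _query_on_executor(_CHECK_QUERY, ...).collect():
# wrap the check query in a derived table and let an executor task fetch it.
df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://db-host:5432/app")  # assumed JDBC URL
    .option("dbtable", "(SELECT 1) AS check_query")       # check query as a subquery
    .option("fetchsize", "1")
    .load()
)
df.collect()  # the fetch itself runs inside an executor task, not on the driver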
18 changes: 0 additions & 18 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -77,7 +77,6 @@ class JDBCMixin:
ExecuteOptions = JDBCExecuteOptions

DRIVER: ClassVar[str]
_CHECK_QUERY: ClassVar[str] = "SELECT 1"

@property
@abstractmethod
@@ -141,23 +140,6 @@ def __enter__(self):
def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: U101
self.close()

@slot
def check(self):
log.info("|%s| Checking connection availability...", self.__class__.__name__)
self._log_parameters() # type: ignore

log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__)
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)

try:
self._query_optional_on_driver(self._CHECK_QUERY, self.FetchOptions(fetchsize=1))
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
raise RuntimeError("Connection is unavailable") from e

return self

@slot
def fetch(
self,
16 changes: 14 additions & 2 deletions onetl/connection/db_connection/kafka/connection.py
@@ -16,7 +16,12 @@

from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_client_info, get_spark_version, stringify
from onetl._util.spark import (
get_client_info,
get_spark_version,
override_job_description,
stringify,
)
from onetl._util.version import Version
from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.kafka.dialect import KafkaDialect
@@ -223,7 +228,14 @@ def check(self):
self._log_parameters()

try:
self._get_topics()
with override_job_description(self.spark, f"{self}.check()"):
self._get_topics()

read_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()}
# Reading from any topic we are allowed to access is enough to check that Kafka is alive
read_options["subscribePattern"] = ".*"
self.spark.read.format("kafka").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
13 changes: 12 additions & 1 deletion onetl/connection/db_connection/mongodb/connection.py
@@ -5,7 +5,7 @@
import json
import logging
import warnings
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar
from urllib import parse as parser

from etl_entities.instance import Host
@@ -129,6 +129,9 @@ class MongoDB(DBConnection):
PipelineOptions = MongoDBPipelineOptions
Extra = MongoDBExtra

# small collection, always present on cluster
_CHECK_DUMMY_COLLECTION: ClassVar[str] = "admin.system.version"

@slot
@classmethod
def get_packages(
@@ -372,6 +375,14 @@ def check(self):
jvm = self.spark._jvm # type: ignore
client = jvm.com.mongodb.client.MongoClients.create(self.connection_url)
list(client.listDatabaseNames().iterator())

with override_job_description(self.spark, f"{self}.check()"):
read_options = {
"connection.uri": self.connection_url,
"collection": self._CHECK_DUMMY_COLLECTION,
}
self.spark.read.format("mongodb").options(**read_options).load().take(1)

log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
2 changes: 1 addition & 1 deletion requirements/tests/oracle.txt
@@ -1 +1 @@
cx_Oracle
oracledb
22 changes: 11 additions & 11 deletions tests/fixtures/processing/oracle.py
@@ -4,7 +4,7 @@
from logging import getLogger
from urllib.parse import quote

import cx_Oracle
import oracledb
import pandas
from pandas.io import sql as psql

@@ -69,21 +69,21 @@ def schema(self) -> str:

@property
def url(self) -> str:
dsn = cx_Oracle.makedsn(self.host, self.port, sid=self.sid, service_name=self.service_name)
return f"oracle://{self.user}:{quote(self.password)}@{dsn}"
dsn = oracledb.makedsn(self.host, self.port, sid=self.sid, service_name=self.service_name)
return f"oracle+oracledb://{self.user}:{quote(self.password)}@{dsn}"

def get_dsn(self) -> cx_Oracle.Dsn:
def get_dsn(self) -> str:
try:
cx_Oracle.init_oracle_client(lib_dir=os.getenv("ONETL_ORA_CLIENT_PATH"))
oracledb.init_oracle_client(lib_dir=os.getenv("ONETL_ORA_CLIENT_PATH"))
except Exception:
logger.debug("cx_Oracle client is already initialized.", exc_info=True)
return cx_Oracle.makedsn(self.host, self.port, sid=self.sid, service_name=self.service_name)
logger.debug("Oracle client is already initialized.", exc_info=True)
return oracledb.makedsn(self.host, self.port, sid=self.sid, service_name=self.service_name)

def get_conn(self) -> cx_Oracle.Connection:
return cx_Oracle.connect(user=self.user, password=self.password, dsn=self.get_dsn())
def get_conn(self) -> oracledb.Connection:
return oracledb.connect(user=self.user, password=self.password, dsn=self.get_dsn())

def get_root_conn(self) -> cx_Oracle.Connection:
return cx_Oracle.connect(user=self.root_user, password=self.root_password, dsn=self.get_dsn())
def get_root_conn(self) -> oracledb.Connection:
return oracledb.connect(user=self.root_user, password=self.root_password, dsn=self.get_dsn())

def create_schema_ddl(
self,
@@ -38,7 +38,8 @@ def test_spark_metrics_recorder_postgres_read(spark, processing, load_table_data

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert metrics.input.read_rows == rows
# +1 is just postgres.check()
assert metrics.input.read_rows == rows + 1
# JDBC does not provide information about data size
assert not metrics.input.read_bytes

@@ -64,7 +65,8 @@ def test_spark_metrics_recorder_postgres_read_empty_source(spark, processing, pr

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, processing, load_table_data):
@@ -89,7 +91,8 @@ def test_spark_metrics_recorder_postgres_read_no_data_after_filter(spark, proces

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.input.read_rows
# 1 is just postgres.check()
assert metrics.input.read_rows == 1


def test_spark_metrics_recorder_postgres_sql(spark, processing, load_table_data):
@@ -16,6 +16,7 @@ def test_teradata_connection_check(spark, mocker, caplog):
password = secrets.token_hex()

mocker.patch.object(Teradata, "_query_optional_on_driver")
mocker.patch.object(Teradata, "_query_on_executor")

teradata = Teradata(
host=host,
