diff --git a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql/data_source.py b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql/data_source.py index b4f1ed35e..c088cd0b9 100644 --- a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql/data_source.py +++ b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql/data_source.py @@ -38,6 +38,9 @@ class PostgreSQLDataSourceMixin(BaseSQLDataSource): def is_compatible_with_type(cls, source_type: DataSourceType) -> bool: return source_type in (SOURCE_TYPE_PG_TABLE, SOURCE_TYPE_PG_SUBSELECT) + def get_connect_args(self) -> dict: + return dict(super().get_connect_args(), server_version=self.db_version) + class PostgreSQLDataSource(PostgreSQLDataSourceMixin, StandardSchemaSQLDataSource): """PG table""" diff --git a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/adapters_postgres.py b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/adapters_postgres.py index 2cb3c0eef..9129ad70d 100644 --- a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/adapters_postgres.py +++ b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/adapters_postgres.py @@ -41,10 +41,11 @@ class PostgresAdapter(BasePostgresAdapter, BaseClassicAdapter[PostgresConnTarget } def get_connect_args(self) -> dict: - return { - "sslmode": "require" if self._target_dto.ssl_enable else "prefer", - "sslrootcert": self.get_ssl_cert_path(self._target_dto.ssl_ca), - } + return dict( + super().get_connect_args(), + sslmode="require" if self._target_dto.ssl_enable else "prefer", + sslrootcert=self.get_ssl_cert_path(self._target_dto.ssl_ca), + ) def get_engine_kwargs(self) -> dict: result: Dict = {} diff --git a/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_sa_dialect.py b/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_sa_dialect.py deleted file mode 100644 index 96313a9fc..000000000 --- a/lib/dl_connector_postgresql/dl_connector_postgresql_tests/db/core/test_sa_dialect.py +++ /dev/null @@ -1,40 +0,0 @@ -import datetime - -import pytest -import pytz -import sqlalchemy as sa - -from dl_connector_postgresql_tests.db.core.base import BasePostgreSQLTestClass - - -TEST_VALUES = [datetime.date(2020, 1, 1)] + [ - datetime.datetime(2020, idx1 + 1, idx2 + 1, 3, 4, 5, us).replace(tzinfo=tzinfo) - for idx1, us in enumerate((0, 123356)) - for idx2, tzinfo in enumerate( - ( - None, - datetime.timezone.utc, - pytz.timezone("America/New_York"), - ) - ) -] - - -class TestPostgresqlSaDialect(BasePostgreSQLTestClass): - @pytest.mark.parametrize("value", TEST_VALUES, ids=[val.isoformat() for val in TEST_VALUES]) - def test_pg_literal_bind_datetimes(self, value, db): - """ - Test that query results for literal_binds matches the query results without, - for the custom dialect code. - - This test should be in the bi_postgresql dialect itself, but it doesn't have - a postgres-using test at the moment. - """ - execute = db.execute - dialect = db._engine_wrapper.dialect - - query = sa.select([sa.literal(value)]) - compiled = str(query.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) - res_direct = list(execute(query)) - res_literal = list(execute(compiled)) - assert res_direct == res_literal, dict(literal_query=compiled) diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py index 229e9ceea..e6279c7ef 100644 --- a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py @@ -1,6 +1,8 @@ from __future__ import annotations import datetime +import logging +import typing from typing import ( Any, Optional, @@ -13,6 +15,9 @@ from dl_sqlalchemy_common.base import CompilerPrettyMixin +LOGGER = logging.getLogger(__name__) + + class CITEXT(sqltypes.TEXT): """Provide the PostgreSQL CITEXT type. @@ -122,6 +127,22 @@ def __init__(self, enforce_collate=None, **kwargs): type_compiler = BICustomPGTypeCompiler statement_compiler = BIPGCompilerBasic ischema_names = bi_pg_ischema_names + forced_server_version_string: str | None = None + error_server_version_info: tuple[int, ...] = (9, 3) + + def connect(self, *cargs: typing.Any, **cparams: typing.Any): + self.forced_server_version_string = cparams.pop("server_version", self.forced_server_version_string) + return super().connect(*cargs, **cparams) + + def _get_server_version_info(self, connection) -> tuple[int, ...]: + if self.forced_server_version_string is not None: + return tuple(int(part) for part in self.forced_server_version_string.split(".")) + + try: + return super()._get_server_version_info(connection) + except Exception: + LOGGER.exception("Failed to get server version info, assuming %s", self.error_server_version_info) + return self.error_server_version_info class BIPGDialect(BIPGDialectBasic): diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/__init__.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/conftest.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/conftest.py new file mode 100644 index 000000000..b43b1479b --- /dev/null +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/conftest.py @@ -0,0 +1,29 @@ +import pytest +import sqlalchemy +import sqlalchemy.engine as sqlalchemy_engine +import sqlalchemy.orm as sqlalchemy_orm + +from dl_testing.containers import get_test_container_hostport +from dl_testing.utils import wait_for_port + + +@pytest.fixture(scope="session") +def engine_url() -> str: + db_postgres_hostport = get_test_container_hostport("db-postgres", fallback_port=52301) + wait_for_port(db_postgres_hostport.host, db_postgres_hostport.port) + return f"bi_postgresql://datalens:qwerty@{db_postgres_hostport.as_pair()}/test_data" + + +@pytest.fixture +def sa_engine(engine_url: str) -> sqlalchemy_engine.Engine: + return sqlalchemy.create_engine(engine_url) + + +@pytest.fixture +def sa_session_maker(sa_engine) -> sqlalchemy_orm.sessionmaker: + return sqlalchemy_orm.sessionmaker(bind=sa_engine) + + +@pytest.fixture +def sa_session(sa_session_maker) -> sqlalchemy_orm.Session: + return sa_session_maker() diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_pg_literal_bind.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_pg_literal_bind.py new file mode 100644 index 000000000..9b5be2fe1 --- /dev/null +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_pg_literal_bind.py @@ -0,0 +1,32 @@ +import datetime + +import pytest +import pytz +import sqlalchemy +import sqlalchemy.engine as sqlalchemy_engine +import sqlalchemy.orm as sqlalchemy_orm + + +@pytest.mark.parametrize( + "timezone", + ( + None, + datetime.timezone.utc, + pytz.timezone("America/New_York"), + ), + ids=["no_timezone", "utc_timezone", "ny_timezone"], +) +@pytest.mark.parametrize("microseconds", (0, 123356), ids=["no_microseconds", "with_microseconds"]) +def test_pg_literal_bind_datetimes( + timezone: datetime.timezone | None, + microseconds: int, + sa_engine: sqlalchemy_engine.Engine, + sa_session: sqlalchemy_orm.Session, +): + value = datetime.datetime(2020, 1, 1, 3, 4, 5, microseconds).replace(tzinfo=timezone) + + query = sqlalchemy.select([sqlalchemy.literal(value)]) + compiled = str(query.compile(dialect=sa_engine.dialect, compile_kwargs={"literal_binds": True})) + res_direct = list(sa_session.execute(query)) + res_literal = list(sa_session.execute(compiled)) + assert res_direct == res_literal, dict(literal_query=compiled) diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_server_version.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_server_version.py new file mode 100644 index 000000000..122f938df --- /dev/null +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/db/test_server_version.py @@ -0,0 +1,40 @@ +import mock +import sqlalchemy +import sqlalchemy.dialects.postgresql.psycopg2 as sqlalchemy_dialect_psycopg2 +import sqlalchemy.orm as sqlalchemy_orm + + +SERVER_VERSION_INFO = (123, 45, 67, 89) +SERVER_VERSION = ".".join(map(str, SERVER_VERSION_INFO)) + + +@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info") +def test_server_version_default(patched_server_version: mock.Mock, sa_session: sqlalchemy_orm.Session): + patched_server_version.return_value = SERVER_VERSION_INFO + + sa_session.scalar("select 1") + + assert sa_session.get_bind().dialect.server_version_info == SERVER_VERSION_INFO + patched_server_version.assert_called_once() + + +@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info") +def test_server_version_error(patched_server_version: mock.Mock, sa_session: sqlalchemy_orm.Session): + patched_server_version.side_effect = AssertionError + + sa_session.scalar("select 1") + + assert sa_session.get_bind().dialect.server_version_info == (9, 3) + patched_server_version.assert_called_once() + + +@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info") +def test_server_version_overwritten(patched_server_version: mock.Mock, engine_url: str): + sa_engine = sqlalchemy.create_engine(engine_url, connect_args=dict(server_version=SERVER_VERSION)) + sa_session_maker = sqlalchemy_orm.sessionmaker(bind=sa_engine) + sa_session = sa_session_maker() + + sa_session.scalar("select 1") + + assert sa_session.get_bind().dialect.server_version_info == SERVER_VERSION_INFO + patched_server_version.assert_not_called() diff --git a/lib/dl_sqlalchemy_postgres/docker-compose.yml b/lib/dl_sqlalchemy_postgres/docker-compose.yml new file mode 100644 index 000000000..6cf281e41 --- /dev/null +++ b/lib/dl_sqlalchemy_postgres/docker-compose.yml @@ -0,0 +1,11 @@ +version: '3.7' + +services: + db-postgres: + image: "postgres:13-alpine@sha256:b9f66c57932510574fb17bccd175776535cec9abcfe7ba306315af2f0b7bfbb4" + environment: + - POSTGRES_DB=test_data + - POSTGRES_USER=datalens + - POSTGRES_PASSWORD=qwerty + ports: + - "50319:5432" diff --git a/lib/dl_sqlalchemy_postgres/pyproject.toml b/lib/dl_sqlalchemy_postgres/pyproject.toml index 06342858c..ca001f962 100644 --- a/lib/dl_sqlalchemy_postgres/pyproject.toml +++ b/lib/dl_sqlalchemy_postgres/pyproject.toml @@ -28,9 +28,10 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] minversion = "6.0" addopts = "-ra" -testpaths = ["dl_sqlalchemy_postgres_tests/unit"] - - +testpaths = [ + "dl_sqlalchemy_postgres_tests/unit", + "dl_sqlalchemy_postgres_tests/db", +] [tool.mypy] warn_unused_configs = true diff --git a/lib/dl_testing/dl_testing/utils.py b/lib/dl_testing/dl_testing/utils.py index af4b9c2dc..054f6c22b 100644 --- a/lib/dl_testing/dl_testing/utils.py +++ b/lib/dl_testing/dl_testing/utils.py @@ -4,6 +4,8 @@ import http.client import logging import os +import socket +import time from typing import ( Any, Callable, @@ -57,6 +59,23 @@ def check_initdb_liveness() -> Tuple[bool, Any]: ) +def wait_for_port(host: str, port: int, period_seconds: int = 1, timeout_seconds: int = 10): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + time_start = time.time() + + while time.time() - time_start < timeout_seconds: + try: + sock.connect((host, port)) + sock.close() + LOGGER.info(f"{host}:{port} is available") + return + except socket.error: + LOGGER.warning(f"Waiting for {host}:{port} to become available") + time.sleep(period_seconds) + + raise Exception(f"Timeout waiting for {host}:{port} to become available") + + @overload def get_log_record( caplog: Any, predicate: Callable[[logging.LogRecord], bool], single: Literal[True] = True