Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sqlalchemy server version info without request #139

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class PostgreSQLDataSourceMixin(BaseSQLDataSource):
def is_compatible_with_type(cls, source_type: DataSourceType) -> bool:
return source_type in (SOURCE_TYPE_PG_TABLE, SOURCE_TYPE_PG_SUBSELECT)

def get_connect_args(self) -> dict:
return dict(super().get_connect_args(), server_version=self.db_version)


class PostgreSQLDataSource(PostgreSQLDataSourceMixin, StandardSchemaSQLDataSource):
"""PG table"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class PostgresAdapter(BasePostgresAdapter, BaseClassicAdapter[PostgresConnTarget
}

def get_connect_args(self) -> dict:
return {
"sslmode": "require" if self._target_dto.ssl_enable else "prefer",
"sslrootcert": self.get_ssl_cert_path(self._target_dto.ssl_ca),
}
return dict(
super().get_connect_args(),
sslmode="require" if self._target_dto.ssl_enable else "prefer",
sslrootcert=self.get_ssl_cert_path(self._target_dto.ssl_ca),
)

def get_engine_kwargs(self) -> dict:
result: Dict = {}
Expand Down

This file was deleted.

21 changes: 21 additions & 0 deletions lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import datetime
import logging
import typing
from typing import (
Any,
Optional,
Expand All @@ -13,6 +15,9 @@
from dl_sqlalchemy_common.base import CompilerPrettyMixin


LOGGER = logging.getLogger(__name__)


class CITEXT(sqltypes.TEXT):

"""Provide the PostgreSQL CITEXT type.
Expand Down Expand Up @@ -122,6 +127,22 @@ def __init__(self, enforce_collate=None, **kwargs):
type_compiler = BICustomPGTypeCompiler
statement_compiler = BIPGCompilerBasic
ischema_names = bi_pg_ischema_names
forced_server_version_string: str | None = None
error_server_version_info: tuple[int, ...] = (9, 3)

def connect(self, *cargs: typing.Any, **cparams: typing.Any):
self.forced_server_version_string = cparams.pop("server_version", self.forced_server_version_string)
return super().connect(*cargs, **cparams)

def _get_server_version_info(self, connection) -> tuple[int, ...]:
if self.forced_server_version_string is not None:
return tuple(int(part) for part in self.forced_server_version_string.split("."))

try:
return super()._get_server_version_info(connection)
except Exception:
LOGGER.exception("Failed to get server version info, assuming %s", self.error_server_version_info)
return self.error_server_version_info


class BIPGDialect(BIPGDialectBasic):
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import sqlalchemy
import sqlalchemy.engine as sqlalchemy_engine
import sqlalchemy.orm as sqlalchemy_orm

from dl_testing.containers import get_test_container_hostport
from dl_testing.utils import wait_for_port


@pytest.fixture(scope="session")
def engine_url() -> str:
db_postgres_hostport = get_test_container_hostport("db-postgres", fallback_port=52301)
wait_for_port(db_postgres_hostport.host, db_postgres_hostport.port)
return f"bi_postgresql://datalens:qwerty@{db_postgres_hostport.as_pair()}/test_data"


@pytest.fixture
def sa_engine(engine_url: str) -> sqlalchemy_engine.Engine:
return sqlalchemy.create_engine(engine_url)


@pytest.fixture
def sa_session_maker(sa_engine) -> sqlalchemy_orm.sessionmaker:
return sqlalchemy_orm.sessionmaker(bind=sa_engine)


@pytest.fixture
def sa_session(sa_session_maker) -> sqlalchemy_orm.Session:
return sa_session_maker()
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import datetime

import pytest
import pytz
import sqlalchemy
import sqlalchemy.engine as sqlalchemy_engine
import sqlalchemy.orm as sqlalchemy_orm


@pytest.mark.parametrize(
"timezone",
(
None,
datetime.timezone.utc,
pytz.timezone("America/New_York"),
),
ids=["no_timezone", "utc_timezone", "ny_timezone"],
)
@pytest.mark.parametrize("microseconds", (0, 123356), ids=["no_microseconds", "with_microseconds"])
def test_pg_literal_bind_datetimes(
timezone: datetime.timezone | None,
microseconds: int,
sa_engine: sqlalchemy_engine.Engine,
sa_session: sqlalchemy_orm.Session,
):
value = datetime.datetime(2020, 1, 1, 3, 4, 5, microseconds).replace(tzinfo=timezone)

query = sqlalchemy.select([sqlalchemy.literal(value)])
compiled = str(query.compile(dialect=sa_engine.dialect, compile_kwargs={"literal_binds": True}))
res_direct = list(sa_session.execute(query))
res_literal = list(sa_session.execute(compiled))
assert res_direct == res_literal, dict(literal_query=compiled)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import mock
import sqlalchemy
import sqlalchemy.dialects.postgresql.psycopg2 as sqlalchemy_dialect_psycopg2
import sqlalchemy.orm as sqlalchemy_orm


SERVER_VERSION_INFO = (123, 45, 67, 89)
SERVER_VERSION = ".".join(map(str, SERVER_VERSION_INFO))


@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info")
def test_server_version_default(patched_server_version: mock.Mock, sa_session: sqlalchemy_orm.Session):
patched_server_version.return_value = SERVER_VERSION_INFO

sa_session.scalar("select 1")

assert sa_session.get_bind().dialect.server_version_info == SERVER_VERSION_INFO
patched_server_version.assert_called_once()


@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info")
def test_server_version_error(patched_server_version: mock.Mock, sa_session: sqlalchemy_orm.Session):
patched_server_version.side_effect = AssertionError

sa_session.scalar("select 1")

assert sa_session.get_bind().dialect.server_version_info == (9, 3)
patched_server_version.assert_called_once()


@mock.patch.object(sqlalchemy_dialect_psycopg2.PGDialect_psycopg2, "_get_server_version_info")
def test_server_version_overwritten(patched_server_version: mock.Mock, engine_url: str):
sa_engine = sqlalchemy.create_engine(engine_url, connect_args=dict(server_version=SERVER_VERSION))
sa_session_maker = sqlalchemy_orm.sessionmaker(bind=sa_engine)
sa_session = sa_session_maker()

sa_session.scalar("select 1")

assert sa_session.get_bind().dialect.server_version_info == SERVER_VERSION_INFO
patched_server_version.assert_not_called()
11 changes: 11 additions & 0 deletions lib/dl_sqlalchemy_postgres/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
version: '3.7'

services:
db-postgres:
image: "postgres:13-alpine@sha256:b9f66c57932510574fb17bccd175776535cec9abcfe7ba306315af2f0b7bfbb4"
environment:
- POSTGRES_DB=test_data
- POSTGRES_USER=datalens
- POSTGRES_PASSWORD=qwerty
ports:
- "50319:5432"
7 changes: 4 additions & 3 deletions lib/dl_sqlalchemy_postgres/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra"
testpaths = ["dl_sqlalchemy_postgres_tests/unit"]


testpaths = [
"dl_sqlalchemy_postgres_tests/unit",
"dl_sqlalchemy_postgres_tests/db",
]

[tool.mypy]
warn_unused_configs = true
Expand Down
19 changes: 19 additions & 0 deletions lib/dl_testing/dl_testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import http.client
import logging
import os
import socket
import time
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -57,6 +59,23 @@ def check_initdb_liveness() -> Tuple[bool, Any]:
)


def wait_for_port(host: str, port: int, period_seconds: int = 1, timeout_seconds: int = 10):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
time_start = time.time()

while time.time() - time_start < timeout_seconds:
try:
sock.connect((host, port))
sock.close()
LOGGER.info(f"{host}:{port} is available")
return
except socket.error:
LOGGER.warning(f"Waiting for {host}:{port} to become available")
time.sleep(period_seconds)

raise Exception(f"Timeout waiting for {host}:{port} to become available")


@overload
def get_log_record(
caplog: Any, predicate: Callable[[logging.LogRecord], bool], single: Literal[True] = True
Expand Down
Loading