diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py index 229e9ceea..f3a4c30fc 100644 --- a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres/base.py @@ -122,6 +122,17 @@ def __init__(self, enforce_collate=None, **kwargs): type_compiler = BICustomPGTypeCompiler statement_compiler = BIPGCompilerBasic ischema_names = bi_pg_ischema_names + forced_server_version_string: str | None = None + + def connect(self, *cargs, **cparams): + self.forced_server_version_string = cparams.pop("server_version", self.forced_server_version_string) + return super().connect(*cargs, **cparams) + + def _get_server_version_info(self, connection) -> tuple[int, ...]: + if self.forced_server_version_string is not None: + return tuple(int(part) for part in self.forced_server_version_string.split(".")) + + return super()._get_server_version_info(connection) class BIPGDialect(BIPGDialectBasic): diff --git a/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/unit/test_server_version.py b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/unit/test_server_version.py new file mode 100644 index 000000000..959fafe8a --- /dev/null +++ b/lib/dl_sqlalchemy_postgres/dl_sqlalchemy_postgres_tests/unit/test_server_version.py @@ -0,0 +1,16 @@ +import sqlalchemy +import sqlalchemy.orm as sqlalchemy_orm + + +SERVER_VERSION_INFO = (123, 45, 67, 89) +SERVER_VERSION = ".".join(map(str, SERVER_VERSION_INFO)) + + +def test_server_version(engine_url: str): + engine = sqlalchemy.create_engine(engine_url, connect_args=dict(server_version=SERVER_VERSION)) + session_maker = sqlalchemy_orm.sessionmaker(bind=engine) + engine_session = session_maker() + # Connection and version get initialized on first query: + engine_session.scalar("select 1") + ver = engine_session.get_bind().dialect.server_version_info + assert ver == SERVER_VERSION_INFO