Skip to content

Commit

Permalink
Refactoring server version introspection.
Browse files Browse the repository at this point in the history
Also should fix #1876 which reports incompatibility in the handling of
MySQLdb server version.
  • Loading branch information
coleifer committed Mar 6, 2019
1 parent a081eb0 commit 4178b42
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
35 changes: 22 additions & 13 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,7 @@ class Database(_callable_context_manager):
operations = {}
param = '?'
quote = '""'
server_version = None

# Feature toggles.
commit_select = False
Expand Down Expand Up @@ -2775,11 +2776,16 @@ def connect(self, reuse_if_open=False):
with __exception_wrapper__:
self._state.set_connection(self._connect())
self._initialize_connection(self._state.conn)
if self.server_version is None:
self._set_server_version(self._state.conn)
return True

def _initialize_connection(self, conn):
pass

def _set_server_version(self, conn):
self.server_version = 0

def close(self):
with self._lock:
if self.deferred:
Expand Down Expand Up @@ -3030,7 +3036,7 @@ class SqliteDatabase(Database):
'ILIKE': 'LIKE'}
index_schema_prefix = True
limit_max = -1
_sqlite_version = __sqlite_version__
server_version = __sqlite_version__

def __init__(self, database, *args, **kwargs):
self._pragmas = kwargs.pop('pragmas', ())
Expand All @@ -3053,6 +3059,9 @@ def init(self, database, pragmas=None, timeout=5, **kwargs):
self._timeout = timeout
super(SqliteDatabase, self).init(database, **kwargs)

def _set_server_version(self, conn):
pass

def _connect(self):
if sqlite3 is None:
raise ImproperlyConfigured('SQLite driver not installed!')
Expand All @@ -3074,7 +3083,7 @@ def _add_conn_hooks(self, conn):
self._load_aggregates(conn)
self._load_collations(conn)
self._load_functions(conn)
if self._sqlite_version >= (3, 25, 0):
if self.server_version >= (3, 25, 0):
self._load_window_functions(conn)
if self._table_functions:
for table_function in self._table_functions:
Expand Down Expand Up @@ -3351,7 +3360,7 @@ def conflict_statement(self, on_conflict, query):

def conflict_update(self, oc, query):
# Sqlite prior to 3.24.0 does not support Postgres-style upsert.
if self._sqlite_version < (3, 24, 0) and \
if self.server_version < (3, 24, 0) and \
any((oc._preserve, oc._update, oc._where, oc._conflict_target,
oc._conflict_constraint)):
raise ValueError('SQLite does not support specifying which values '
Expand Down Expand Up @@ -3408,7 +3417,6 @@ class PostgresqlDatabase(Database):
def init(self, database, register_unicode=True, encoding=None, **kwargs):
self._register_unicode = register_unicode
self._encoding = encoding
self._need_server_version = True
super(PostgresqlDatabase, self).init(database, **kwargs)

def _connect(self):
Expand All @@ -3420,13 +3428,11 @@ def _connect(self):
pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
if self._encoding:
conn.set_client_encoding(self._encoding)
if self._need_server_version:
self.set_server_version(conn.server_version)
self._need_server_version = False
return conn

def set_server_version(self, version):
if version >= 90600:
def _set_server_version(self, conn):
self.server_version = conn.server_version
if self.server_version >= 90600:
self.safe_create_index = True

def last_insert_id(self, cursor, query_type=None):
Expand Down Expand Up @@ -3580,7 +3586,6 @@ class MySQLDatabase(Database):
limit_max = 2 ** 64 - 1
safe_create_index = False
safe_drop_index = False
_server_version = None

def init(self, database, **kwargs):
params = {'charset': 'utf8', 'use_unicode': True}
Expand All @@ -3593,11 +3598,15 @@ def _connect(self):
if mysql is None:
raise ImproperlyConfigured('MySQL driver not installed!')
conn = mysql.connect(db=self.database, **self.connect_params)
if self._server_version is None:
version_raw = conn.server_version
self._server_version = self._extract_server_version(version_raw)
return conn

def _set_server_version(self, conn):
try:
version_raw = conn.server_version
except AttributeError:
version_raw = conn.get_server_info()
self.server_version = self._extract_server_version(version_raw)

def _extract_server_version(self, version):
version = version.lower()
if 'maria' in version:
Expand Down
2 changes: 1 addition & 1 deletion playhouse/apsw_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class APSWDatabase(SqliteExtDatabase):
_sqlite_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))
server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))

def __init__(self, database, **kwargs):
self._modules = {}
Expand Down

0 comments on commit 4178b42

Please sign in to comment.