diff --git a/src/porm/databases/api/__init__.py b/src/porm/databases/api/__init__.py index b8a7183..e853110 100644 --- a/src/porm/databases/api/__init__.py +++ b/src/porm/databases/api/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import threading import uuid @@ -5,6 +7,7 @@ from functools import wraps from typing import List, Dict +from porm.databases.api.drivers import mysql_constants from porm.errors import InterfaceError, OperationalError, __exception_wrapper__ try: # Python 2.7+ @@ -89,7 +92,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _transaction(_callable_context_manager): - def __init__(self, db, lock_type=None): + def __init__(self, db: DBApi, lock_type=None): self.db = db self._lock_type = lock_type @@ -354,7 +357,15 @@ def begin(self, _lock_type=None): self.connect() if _lock_type: # do with lock type - self._state.conn.begin() + self._begin(pessimistic_lock=True) + else: + self._begin() + + def _begin(self, pessimistic_lock=False): + if pessimistic_lock: + # suit for tidb + self._state.conn._execute_command(mysql_constants.COMMAND.COM_QUERY, "BEGIN /*!90000 PESSIMISTIC */") + self._state.conn._read_ok_packet() else: self._state.conn.begin() @@ -412,10 +423,10 @@ def execute_sql(self, sql, params: tuple = None, commit=SENTINEL): cursor = self.cursor(commit) try: cursor.execute(sql, params or ()) - except Exception: + except Exception as ex: if self.autorollback and not self.in_transaction(): self.rollback() - raise + raise ex else: if commit and not self.in_transaction(): self.commit() @@ -433,10 +444,10 @@ def execute_sqls(self, sql, params: List[tuple] = None, commit=SENTINEL): cursor = self.cursor(commit) try: cursor.executemany(sql, params or []) - except Exception: + except Exception as ex: if self.autorollback and not self.in_transaction(): self.rollback() - raise + raise ex else: if commit and not self.in_transaction(): self.commit() diff --git a/src/porm/databases/api/drivers/__init__.py b/src/porm/databases/api/drivers/__init__.py index b236aca..6d56598 100644 --- a/src/porm/databases/api/drivers/__init__.py +++ b/src/porm/databases/api/drivers/__init__.py @@ -1,12 +1,18 @@ +from typing import Union + +import pymysql + __all__ = ( 'mysql' ) mysql_passwd = False -mysql = None +mysql: pymysql = None +mysql_constants: pymysql.constants = None try: import pymysql as mysql + from pymysql import constants as mysql_constants except ImportError: try: import MySQLdb as mysql diff --git a/src/porm/databases/api/mysql.py b/src/porm/databases/api/mysql.py index dd7dd6c..92e5218 100644 --- a/src/porm/databases/api/mysql.py +++ b/src/porm/databases/api/mysql.py @@ -1,8 +1,6 @@ import contextlib import logging -import pymysql - from porm.databases.api import DBApi, _transaction from porm.databases.api.drivers import mysql as driver from porm.errors import EmptyError @@ -23,7 +21,7 @@ def emit(self, record): 'db': 'PORM_DATABASE', 'charset': 'utf8', 'autocommit': 0, # default 0 - 'cursorclass': pymysql.cursors.DictCursor + 'cursorclass': driver.cursors.DictCursor } diff --git a/src/porm/orms/mysql.py b/src/porm/orms/mysql.py index 2816809..4cc9221 100644 --- a/src/porm/orms/mysql.py +++ b/src/porm/orms/mysql.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from collections import defaultdict from enum import Enum, unique -from functools import partial from typing import Union, Dict from porm import BaseType, VarcharType @@ -277,7 +275,11 @@ def join(self, base_table, join_table, **eq_terms) -> Join: def to_json(self) -> Dict[str, Dict[str, tuple]]: """ - :return: {'join_tablename': {'base_tablename.field1': ('value', 'LIKE'), 'base_tablename.field1': ('\\join_tablename.field2\\', '=')}} + :return: { + 'join_tablename': { + 'base_tablename.field1': ('value', 'LIKE'), 'base_tablename.field1': ('\\join_tablename.field2\\', '=') + } + } """ return self._join_terms.copy() diff --git a/tests/test_model.py b/tests/test_model.py index 938c678..110ebca 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -190,6 +190,20 @@ def test_06_transaction_failed(self): pass self.assertIsNone(UserInfo.get_one(email='312dennias.chiu@gmail.com')) + def test_07_duplicate_key(self): + ui = UserInfo.new( + email='dennias.chiu@gmail.com1', username='dennias', height=180, + properties={"yooyo": "hahaha"}) + try: + with ui.start_transaction() as _t: + ui.insert(t=_t) + # test tidb pessimic Pessimistic Lock + self.assertTrue(False) + except Exception as ex: + self.assertEqual(type(ex), pymysql.err.IntegrityError) + self.assertEqual(str(ex), """(1062, "Duplicate entry 'dennias.chiu@gmail.com1' for key 'email'")""") + self.assertIsNone(UserInfo.get_one(email='312dennias.chiu@gmail.com')) + def test_99_drop_table(self): with UserInfo.start_transaction() as _t: UserInfo.drop(t=_t)