Skip to content

Commit

Permalink
Fix regression when passing lock_type= to sqlite atomic/transaction.
Browse files Browse the repository at this point in the history
Fixes #2071
  • Loading branch information
coleifer committed Dec 6, 2019
1 parent dcdf7cb commit 2b9f92c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
30 changes: 16 additions & 14 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,12 @@ class DatabaseProxy(Proxy):
"""
def connection_context(self):
return ConnectionContext(self)
def atomic(self, *args):
return _atomic(self, *args)
def atomic(self, *args, **kwargs):
return _atomic(self, *args, **kwargs)
def manual_commit(self):
return _manual(self)
def transaction(self, *args):
return _transaction(self, *args)
def transaction(self, *args, **kwargs):
return _transaction(self, *args, **kwargs)
def savepoint(self):
return _savepoint(self)

Expand Down Expand Up @@ -3188,14 +3188,14 @@ def top_transaction(self):
if self._state.transactions:
return self._state.transactions[-1]

def atomic(self, *args):
return _atomic(self, *args)
def atomic(self, *args, **kwargs):
return _atomic(self, *args, **kwargs)

def manual_commit(self):
return _manual(self)

def transaction(self, *args):
return _transaction(self, *args)
def transaction(self, *args, **kwargs):
return _transaction(self, *args, **kwargs)

def savepoint(self):
return _savepoint(self)
Expand Down Expand Up @@ -4065,13 +4065,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class _atomic(_callable_context_manager):
def __init__(self, db, *args):
def __init__(self, db, *args, **kwargs):
self.db = db
self._transaction_args = args
self._transaction_args = (args, kwargs)

def __enter__(self):
if self.db.transaction_depth() == 0:
self._helper = self.db.transaction(*self._transaction_args)
args, kwargs = self._transaction_args
self._helper = self.db.transaction(*args, **kwargs)
elif isinstance(self.db.top_transaction(), _manual):
raise ValueError('Cannot enter atomic commit block while in '
'manual commit mode.')
Expand All @@ -4084,12 +4085,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class _transaction(_callable_context_manager):
def __init__(self, db, *args):
def __init__(self, db, *args, **kwargs):
self.db = db
self._begin_args = args
self._begin_args = (args, kwargs)

def _begin(self):
self.db.begin(*self._begin_args)
args, kwargs = self._begin_args
self.db.begin(*args, **kwargs)

def commit(self, begin=True):
self.db.commit()
Expand Down
4 changes: 2 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def make_db_params(key):
handler.setLevel(logging.DEBUG)


def new_connection():
return db_loader(BACKEND, 'peewee_test')
def new_connection(**kwargs):
return db_loader(BACKEND, 'peewee_test', **kwargs)


db = new_connection()
Expand Down
31 changes: 31 additions & 0 deletions tests/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from .base import DatabaseTestCase
from .base import IS_CRDB
from .base import IS_SQLITE
from .base import ModelTestCase
from .base import db
from .base import new_connection
from .base import skip_if
from .base import skip_unless
from .base_models import Register


Expand Down Expand Up @@ -345,3 +348,31 @@ def test_session_commit(self):

self.assertTrue(db.session_rollback())
self.assertRegister([1, 2, 3])


@skip_unless(IS_SQLITE, 'requires sqlite for transaction lock type')
class TestTransactionLockType(BaseTransactionTestCase):
def test_lock_type(self):
db2 = new_connection(timeout=0.001)
db2.connect()

with self.database.atomic(lock_type='EXCLUSIVE') as txn:
with self.assertRaises(OperationalError):
with db2.atomic(lock_type='IMMEDIATE') as t2:
self._save(1)
self._save(2)
self.assertRegister([2])

with self.database.atomic('IMMEDIATE') as txn:
with self.assertRaises(OperationalError):
with db2.atomic('EXCLUSIVE') as t2:
self._save(3)
self._save(4)
self.assertRegister([2, 4])

with self.database.transaction(lock_type='DEFERRED') as txn:
self._save(5) # Deferred -> Exclusive after our write.
with self.assertRaises(OperationalError):
with db2.transaction(lock_type='IMMEDIATE') as t2:
self._save(6)
self.assertRegister([2, 4, 5])

0 comments on commit 2b9f92c

Please sign in to comment.