diff --git a/peewee.py b/peewee.py index 8cfcbb57a..06bd9e115 100644 --- a/peewee.py +++ b/peewee.py @@ -452,12 +452,12 @@ class DatabaseProxy(Proxy): """ def connection_context(self): return ConnectionContext(self) - def atomic(self, *args): - return _atomic(self, *args) + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) def manual_commit(self): return _manual(self) - def transaction(self, *args): - return _transaction(self, *args) + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) def savepoint(self): return _savepoint(self) @@ -3188,14 +3188,14 @@ def top_transaction(self): if self._state.transactions: return self._state.transactions[-1] - def atomic(self, *args): - return _atomic(self, *args) + def atomic(self, *args, **kwargs): + return _atomic(self, *args, **kwargs) def manual_commit(self): return _manual(self) - def transaction(self, *args): - return _transaction(self, *args) + def transaction(self, *args, **kwargs): + return _transaction(self, *args, **kwargs) def savepoint(self): return _savepoint(self) @@ -4065,13 +4065,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _atomic(_callable_context_manager): - def __init__(self, db, *args): + def __init__(self, db, *args, **kwargs): self.db = db - self._transaction_args = args + self._transaction_args = (args, kwargs) def __enter__(self): if self.db.transaction_depth() == 0: - self._helper = self.db.transaction(*self._transaction_args) + args, kwargs = self._transaction_args + self._helper = self.db.transaction(*args, **kwargs) elif isinstance(self.db.top_transaction(), _manual): raise ValueError('Cannot enter atomic commit block while in ' 'manual commit mode.') @@ -4084,12 +4085,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _transaction(_callable_context_manager): - def __init__(self, db, *args): + def __init__(self, db, *args, **kwargs): self.db = db - self._begin_args = args + self._begin_args = (args, kwargs) def _begin(self): - self.db.begin(*self._begin_args) + args, kwargs = self._begin_args + self.db.begin(*args, **kwargs) def commit(self, begin=True): self.db.commit() diff --git a/tests/base.py b/tests/base.py index 2694449d0..c7b782e91 100644 --- a/tests/base.py +++ b/tests/base.py @@ -79,8 +79,8 @@ def make_db_params(key): handler.setLevel(logging.DEBUG) -def new_connection(): - return db_loader(BACKEND, 'peewee_test') +def new_connection(**kwargs): + return db_loader(BACKEND, 'peewee_test', **kwargs) db = new_connection() diff --git a/tests/transactions.py b/tests/transactions.py index 673969e75..7c601b785 100644 --- a/tests/transactions.py +++ b/tests/transactions.py @@ -2,9 +2,12 @@ from .base import DatabaseTestCase from .base import IS_CRDB +from .base import IS_SQLITE from .base import ModelTestCase from .base import db +from .base import new_connection from .base import skip_if +from .base import skip_unless from .base_models import Register @@ -345,3 +348,31 @@ def test_session_commit(self): self.assertTrue(db.session_rollback()) self.assertRegister([1, 2, 3]) + + +@skip_unless(IS_SQLITE, 'requires sqlite for transaction lock type') +class TestTransactionLockType(BaseTransactionTestCase): + def test_lock_type(self): + db2 = new_connection(timeout=0.001) + db2.connect() + + with self.database.atomic(lock_type='EXCLUSIVE') as txn: + with self.assertRaises(OperationalError): + with db2.atomic(lock_type='IMMEDIATE') as t2: + self._save(1) + self._save(2) + self.assertRegister([2]) + + with self.database.atomic('IMMEDIATE') as txn: + with self.assertRaises(OperationalError): + with db2.atomic('EXCLUSIVE') as t2: + self._save(3) + self._save(4) + self.assertRegister([2, 4]) + + with self.database.transaction(lock_type='DEFERRED') as txn: + self._save(5) # Deferred -> Exclusive after our write. + with self.assertRaises(OperationalError): + with db2.transaction(lock_type='IMMEDIATE') as t2: + self._save(6) + self.assertRegister([2, 4, 5])