diff --git a/docs/src/piccolo/query_clauses/index.rst b/docs/src/piccolo/query_clauses/index.rst index f9167ff63..feac07b0f 100644 --- a/docs/src/piccolo/query_clauses/index.rst +++ b/docs/src/piccolo/query_clauses/index.rst @@ -25,6 +25,7 @@ by modifying the return values. ./distinct ./freeze ./group_by + ./lock_rows ./offset ./on_conflict ./output diff --git a/docs/src/piccolo/query_clauses/lock_rows.rst b/docs/src/piccolo/query_clauses/lock_rows.rst new file mode 100644 index 000000000..54d435326 --- /dev/null +++ b/docs/src/piccolo/query_clauses/lock_rows.rst @@ -0,0 +1,93 @@ +.. _lock_rows: + +lock_rows +========= + +You can use the ``lock_rows`` clause with the following queries: + +* :ref:`Objects` +* :ref:`Select` + +It returns a query that locks rows until the end of the transaction, generating a ``SELECT ... FOR UPDATE`` SQL statement or similar with other lock strengths. + +.. note:: Postgres and CockroachDB only. + +------------------------------------------------------------------------------- + +Basic Usage +----------- + +Basic usage without parameters: + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_rows() + +Equivalent to: + +.. code-block:: sql + + SELECT ... FOR UPDATE + + +lock_strength +------------- + +The parameter ``lock_strength`` controls the strength of the row lock when performing an operation in PostgreSQL. +The value can be a predefined constant from the ``LockStrength`` enum or one of the following strings (case-insensitive): + +* ``UPDATE`` (default): Acquires an exclusive lock on the selected rows, preventing other transactions from modifying or locking them until the current transaction is complete. +* ``NO KEY UPDATE`` (Postgres only): Similar to ``UPDATE``, but allows other transactions to insert or delete rows that do not affect the primary key or unique constraints. +* ``KEY SHARE`` (Postgres only): Permits other transactions to acquire key-share or share locks, allowing non-key modifications while preventing updates or deletes. +* ``SHARE``: Acquires a shared lock, allowing other transactions to read the rows but not modify or lock them. + +You can specify a different lock strength: + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_rows('SHARE') + +Which is equivalent to: + +.. code-block:: sql + + SELECT ... FOR SHARE + + +nowait +------ + +If another transaction has already acquired a lock on one or more selected rows, an exception will be raised instead of +waiting for the other transaction to release the lock. + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', nowait=True) + + +skip_locked +----------- + +Ignore locked rows. + +.. code-block:: python + + await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', skip_locked=True) + + +of +-- + +By default, if there are many tables in a query (e.g. when joining), all tables will be locked. +Using ``of``, you can specify which tables should be locked. + +.. code-block:: python + + await Band.select().where(Band.manager.name == 'Guido').lock_rows('UPDATE', of=(Band, )) + + +Learn more +---------- + +* `Postgres docs `_ +* `CockroachDB docs `_ diff --git a/docs/src/piccolo/query_types/objects.rst b/docs/src/piccolo/query_types/objects.rst index 968aefcab..66acf9413 100644 --- a/docs/src/piccolo/query_types/objects.rst +++ b/docs/src/piccolo/query_types/objects.rst @@ -335,6 +335,11 @@ limit See :ref:`limit`. +lock_rows +~~~~~~~~ + +See :ref:`lock_rows`. + offset ~~~~~~ diff --git a/docs/src/piccolo/query_types/select.rst b/docs/src/piccolo/query_types/select.rst index 092291e4c..96c8b06d0 100644 --- a/docs/src/piccolo/query_types/select.rst +++ b/docs/src/piccolo/query_types/select.rst @@ -376,6 +376,12 @@ limit See :ref:`limit`. + +lock_rows +~~~~~~~~ + +See :ref:`lock_rows`. + offset ~~~~~~ diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 7f2b5aaed..41d105b9f 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -13,6 +13,8 @@ CallbackDelegate, CallbackType, LimitDelegate, + LockRowsDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, OrderByRaw, @@ -27,6 +29,7 @@ if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Column + from piccolo.table import Table ############################################################################### @@ -194,6 +197,7 @@ class Objects( "callback_delegate", "prefetch_delegate", "where_delegate", + "lock_rows_delegate", ) def __init__( @@ -213,6 +217,7 @@ def __init__( self.prefetch_delegate = PrefetchDelegate() self.prefetch(*prefetch) self.where_delegate = WhereDelegate() + self.lock_rows_delegate = LockRowsDelegate() def output(self: Self, load_json: bool = False) -> Self: self.output_delegate.output( @@ -272,6 +277,26 @@ def first(self) -> First[TableInstance]: self.limit_delegate.limit(1) return First[TableInstance](query=self) + def lock_rows( + self: Self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: t.Tuple[type[Table], ...] = (), + ) -> Self: + self.lock_rows_delegate.lock_rows( + lock_strength, nowait, skip_locked, of + ) + return self + def get(self, where: Combinable) -> Get[TableInstance]: self.where_delegate.where(where) self.limit_delegate.limit(1) @@ -322,6 +347,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: "offset_delegate", "output_delegate", "order_by_delegate", + "lock_rows_delegate", ): setattr(select, attr, getattr(self, attr)) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 0c590918b..5d7856c5a 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -19,6 +19,8 @@ DistinctDelegate, GroupByDelegate, LimitDelegate, + LockRowsDelegate, + LockStrength, OffsetDelegate, OrderByDelegate, OrderByRaw, @@ -150,6 +152,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]): "output_delegate", "callback_delegate", "where_delegate", + "lock_rows_delegate", ) def __init__( @@ -174,6 +177,7 @@ def __init__( self.output_delegate = OutputDelegate() self.callback_delegate = CallbackDelegate() self.where_delegate = WhereDelegate() + self.lock_rows_delegate = LockRowsDelegate() self.columns(*columns_list) @@ -219,6 +223,26 @@ def offset(self: Self, number: int) -> Self: self.offset_delegate.offset(number) return self + def lock_rows( + self: Self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait: bool = False, + skip_locked: bool = False, + of: t.Tuple[type[Table], ...] = (), + ) -> Self: + self.lock_rows_delegate.lock_rows( + lock_strength, nowait, skip_locked, of + ) + return self + async def _splice_m2m_rows( self, response: t.List[t.Dict[str, t.Any]], @@ -618,6 +642,16 @@ def default_querystrings(self) -> t.Sequence[QueryString]: query += "{}" args.append(self.offset_delegate._offset.querystring) + if self.lock_rows_delegate._lock_rows: + if engine_type == "sqlite": + raise NotImplementedError( + "SQLite doesn't support row locking e.g. SELECT ... FOR " + "UPDATE" + ) + + query += "{}" + args.append(self.lock_rows_delegate._lock_rows.querystring) + querystring = QueryString(query, *args) return [querystring] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index d9d5f84ca..b1b726cf3 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -784,3 +784,91 @@ def on_conflict( target=target, action=action_, values=values, where=where ) ) + + +class LockStrength(str, Enum): + """ + Specify lock strength + + https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE + """ + + update = "UPDATE" + no_key_update = "NO KEY UPDATE" + share = "SHARE" + key_share = "KEY SHARE" + + +@dataclass +class LockRows: + __slots__ = ("lock_strength", "nowait", "skip_locked", "of") + + lock_strength: LockStrength + nowait: bool + skip_locked: bool + of: t.Tuple[t.Type[Table], ...] + + def __post_init__(self): + if not isinstance(self.lock_strength, LockStrength): + raise TypeError("lock_strength must be a LockStrength") + if not isinstance(self.nowait, bool): + raise TypeError("nowait must be a bool") + if not isinstance(self.skip_locked, bool): + raise TypeError("skip_locked must be a bool") + if not isinstance(self.of, tuple) or not all( + hasattr(x, "_meta") for x in self.of + ): + raise TypeError("of must be a tuple of Table") + if self.nowait and self.skip_locked: + raise TypeError( + "The nowait option cannot be used with skip_locked" + ) + + @property + def querystring(self) -> QueryString: + sql = f" FOR {self.lock_strength.value}" + if self.of: + tables = ", ".join( + i._meta.get_formatted_tablename() for i in self.of + ) + sql += " OF " + tables + if self.nowait: + sql += " NOWAIT" + if self.skip_locked: + sql += " SKIP LOCKED" + + return QueryString(sql) + + def __str__(self) -> str: + return self.querystring.__str__() + + +@dataclass +class LockRowsDelegate: + + _lock_rows: t.Optional[LockRows] = None + + def lock_rows( + self, + lock_strength: t.Union[ + LockStrength, + t.Literal[ + "UPDATE", + "NO KEY UPDATE", + "KEY SHARE", + "SHARE", + ], + ] = LockStrength.update, + nowait=False, + skip_locked=False, + of: t.Tuple[type[Table], ...] = (), + ): + lock_strength_: LockStrength + if isinstance(lock_strength, LockStrength): + lock_strength_ = lock_strength + elif isinstance(lock_strength, str): + lock_strength_ = LockStrength(lock_strength.upper()) + else: + raise ValueError("Unrecognised `lock_strength` value.") + + self._lock_rows = LockRows(lock_strength_, nowait, skip_locked, of) diff --git a/tests/table/test_select.py b/tests/table/test_select.py index ebf2c3ff8..74b876d5f 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -1028,6 +1028,40 @@ def test_select_raw(self): response, [{"name": "Pythonistas", "popularity_log": 3.0}] ) + @pytest.mark.skipif( + is_running_sqlite(), + reason="SQLite doesn't support SELECT ... FOR UPDATE.", + ) + def test_lock_rows(self): + """ + Make sure the for_update clause works. + """ + self.insert_rows() + + query = Band.select() + self.assertNotIn("FOR UPDATE", query.__str__()) + + query = query.lock_rows() + self.assertTrue(query.__str__().endswith("FOR UPDATE")) + + query = query.lock_rows(lock_strength="KEY SHARE") + self.assertTrue(query.__str__().endswith("FOR KEY SHARE")) + + query = query.lock_rows(skip_locked=True) + self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED")) + + query = query.lock_rows(nowait=True) + self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT")) + + query = query.lock_rows(of=(Band,)) + self.assertTrue(query.__str__().endswith('FOR UPDATE OF "band"')) + + with self.assertRaises(TypeError): + query = query.lock_rows(skip_locked=True, nowait=True) + + response = query.run_sync() + assert response is not None + class TestSelectSecret(TestCase): def setUp(self):