Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1065 Select for update #1085

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/piccolo/query_clauses/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ by modifying the return values.
./distinct
./freeze
./group_by
./lock_rows
./offset
./on_conflict
./output
Expand Down
93 changes: 93 additions & 0 deletions docs/src/piccolo/query_clauses/lock_rows.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
.. _lock_rows:

lock_rows
=========

You can use the ``lock_rows`` clause with the following queries:

* :ref:`Objects`
* :ref:`Select`

It returns a query that locks rows until the end of the transaction, generating a ``SELECT ... FOR UPDATE`` SQL statement or similar with other lock strengths.

.. note:: Postgres and CockroachDB only.

-------------------------------------------------------------------------------

Basic Usage
-----------

Basic usage without parameters:

.. code-block:: python

await Band.select(Band.name == 'Pythonistas').lock_rows()

Equivalent to:

.. code-block:: sql

SELECT ... FOR UPDATE


lock_strength
-------------

The parameter ``lock_strength`` controls the strength of the row lock when performing an operation in PostgreSQL.
The value can be a predefined constant from the ``LockStrength`` enum or one of the following strings (case-insensitive):

* ``UPDATE`` (default): Acquires an exclusive lock on the selected rows, preventing other transactions from modifying or locking them until the current transaction is complete.
* ``NO KEY UPDATE`` (Postgres only): Similar to ``UPDATE``, but allows other transactions to insert or delete rows that do not affect the primary key or unique constraints.
* ``KEY SHARE`` (Postgres only): Permits other transactions to acquire key-share or share locks, allowing non-key modifications while preventing updates or deletes.
* ``SHARE``: Acquires a shared lock, allowing other transactions to read the rows but not modify or lock them.

You can specify a different lock strength:

.. code-block:: python

await Band.select(Band.name == 'Pythonistas').lock_rows('SHARE')

Which is equivalent to:

.. code-block:: sql

SELECT ... FOR SHARE


nowait
------

If another transaction has already acquired a lock on one or more selected rows, an exception will be raised instead of
waiting for the other transaction to release the lock.

.. code-block:: python

await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', nowait=True)


skip_locked
-----------

Ignore locked rows.

.. code-block:: python

await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', skip_locked=True)


of
--

By default, if there are many tables in a query (e.g. when joining), all tables will be locked.
Using ``of``, you can specify which tables should be locked.

.. code-block:: python

await Band.select().where(Band.manager.name == 'Guido').lock_rows('UPDATE', of=(Band, ))


Learn more
----------

* `Postgres docs <https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE>`_
* `CockroachDB docs <https://www.cockroachlabs.com/docs/stable/select-for-update#lock-strengths>`_
5 changes: 5 additions & 0 deletions docs/src/piccolo/query_types/objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ limit

See :ref:`limit`.

lock_rows
~~~~~~~~

See :ref:`lock_rows`.

offset
~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/src/piccolo/query_types/select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,12 @@ limit

See :ref:`limit`.


lock_rows
~~~~~~~~

See :ref:`lock_rows`.

offset
~~~~~~

Expand Down
26 changes: 26 additions & 0 deletions piccolo/query/methods/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
CallbackDelegate,
CallbackType,
LimitDelegate,
LockRowsDelegate,
LockStrength,
OffsetDelegate,
OrderByDelegate,
OrderByRaw,
Expand All @@ -27,6 +29,7 @@

if t.TYPE_CHECKING: # pragma: no cover
from piccolo.columns import Column
from piccolo.table import Table


###############################################################################
Expand Down Expand Up @@ -194,6 +197,7 @@ class Objects(
"callback_delegate",
"prefetch_delegate",
"where_delegate",
"lock_rows_delegate",
)

def __init__(
Expand All @@ -213,6 +217,7 @@ def __init__(
self.prefetch_delegate = PrefetchDelegate()
self.prefetch(*prefetch)
self.where_delegate = WhereDelegate()
self.lock_rows_delegate = LockRowsDelegate()

def output(self: Self, load_json: bool = False) -> Self:
self.output_delegate.output(
Expand Down Expand Up @@ -272,6 +277,26 @@ def first(self) -> First[TableInstance]:
self.limit_delegate.limit(1)
return First[TableInstance](query=self)

def lock_rows(
self: Self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait: bool = False,
skip_locked: bool = False,
of: t.Tuple[type[Table], ...] = (),
) -> Self:
self.lock_rows_delegate.lock_rows(
lock_strength, nowait, skip_locked, of
)
return self

def get(self, where: Combinable) -> Get[TableInstance]:
self.where_delegate.where(where)
self.limit_delegate.limit(1)
Expand Down Expand Up @@ -322,6 +347,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
"offset_delegate",
"output_delegate",
"order_by_delegate",
"lock_rows_delegate",
):
setattr(select, attr, getattr(self, attr))

Expand Down
34 changes: 34 additions & 0 deletions piccolo/query/methods/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
DistinctDelegate,
GroupByDelegate,
LimitDelegate,
LockRowsDelegate,
LockStrength,
OffsetDelegate,
OrderByDelegate,
OrderByRaw,
Expand Down Expand Up @@ -150,6 +152,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]):
"output_delegate",
"callback_delegate",
"where_delegate",
"lock_rows_delegate",
)

def __init__(
Expand All @@ -174,6 +177,7 @@ def __init__(
self.output_delegate = OutputDelegate()
self.callback_delegate = CallbackDelegate()
self.where_delegate = WhereDelegate()
self.lock_rows_delegate = LockRowsDelegate()

self.columns(*columns_list)

Expand Down Expand Up @@ -219,6 +223,26 @@ def offset(self: Self, number: int) -> Self:
self.offset_delegate.offset(number)
return self

def lock_rows(
self: Self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait: bool = False,
skip_locked: bool = False,
of: t.Tuple[type[Table], ...] = (),
) -> Self:
self.lock_rows_delegate.lock_rows(
lock_strength, nowait, skip_locked, of
)
return self

async def _splice_m2m_rows(
self,
response: t.List[t.Dict[str, t.Any]],
Expand Down Expand Up @@ -618,6 +642,16 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
query += "{}"
args.append(self.offset_delegate._offset.querystring)

if self.lock_rows_delegate._lock_rows:
if engine_type == "sqlite":
raise NotImplementedError(
"SQLite doesn't support row locking e.g. SELECT ... FOR "
"UPDATE"
)

query += "{}"
args.append(self.lock_rows_delegate._lock_rows.querystring)

querystring = QueryString(query, *args)

return [querystring]
Expand Down
88 changes: 88 additions & 0 deletions piccolo/query/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,91 @@ def on_conflict(
target=target, action=action_, values=values, where=where
)
)


class LockStrength(str, Enum):
"""
Specify lock strength

https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE
"""

update = "UPDATE"
no_key_update = "NO KEY UPDATE"
share = "SHARE"
key_share = "KEY SHARE"


@dataclass
class LockRows:
__slots__ = ("lock_strength", "nowait", "skip_locked", "of")

lock_strength: LockStrength
nowait: bool
skip_locked: bool
of: t.Tuple[t.Type[Table], ...]

def __post_init__(self):
if not isinstance(self.lock_strength, LockStrength):
raise TypeError("lock_strength must be a LockStrength")
if not isinstance(self.nowait, bool):
raise TypeError("nowait must be a bool")
if not isinstance(self.skip_locked, bool):
raise TypeError("skip_locked must be a bool")
if not isinstance(self.of, tuple) or not all(
hasattr(x, "_meta") for x in self.of
):
raise TypeError("of must be a tuple of Table")
if self.nowait and self.skip_locked:
raise TypeError(
"The nowait option cannot be used with skip_locked"
)

@property
def querystring(self) -> QueryString:
sql = f" FOR {self.lock_strength.value}"
if self.of:
tables = ", ".join(
i._meta.get_formatted_tablename() for i in self.of
)
sql += " OF " + tables
if self.nowait:
sql += " NOWAIT"
if self.skip_locked:
sql += " SKIP LOCKED"

return QueryString(sql)

def __str__(self) -> str:
return self.querystring.__str__()


@dataclass
class LockRowsDelegate:

_lock_rows: t.Optional[LockRows] = None

def lock_rows(
self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait=False,
skip_locked=False,
of: t.Tuple[type[Table], ...] = (),
):
lock_strength_: LockStrength
if isinstance(lock_strength, LockStrength):
lock_strength_ = lock_strength
elif isinstance(lock_strength, str):
lock_strength_ = LockStrength(lock_strength.upper())
else:
raise ValueError("Unrecognised `lock_strength` value.")

self._lock_rows = LockRows(lock_strength_, nowait, skip_locked, of)
34 changes: 34 additions & 0 deletions tests/table/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,40 @@ def test_select_raw(self):
response, [{"name": "Pythonistas", "popularity_log": 3.0}]
)

@pytest.mark.skipif(
is_running_sqlite(),
reason="SQLite doesn't support SELECT ... FOR UPDATE.",
)
def test_lock_rows(self):
"""
Make sure the for_update clause works.
"""
self.insert_rows()

query = Band.select()
self.assertNotIn("FOR UPDATE", query.__str__())

query = query.lock_rows()
self.assertTrue(query.__str__().endswith("FOR UPDATE"))

query = query.lock_rows(lock_strength="KEY SHARE")
self.assertTrue(query.__str__().endswith("FOR KEY SHARE"))

query = query.lock_rows(skip_locked=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED"))

query = query.lock_rows(nowait=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT"))

query = query.lock_rows(of=(Band,))
self.assertTrue(query.__str__().endswith('FOR UPDATE OF "band"'))

with self.assertRaises(TypeError):
query = query.lock_rows(skip_locked=True, nowait=True)

response = query.run_sync()
assert response is not None


class TestSelectSecret(TestCase):
def setUp(self):
Expand Down
Loading