Skip to content

Commit eb4b7ba

Browse files
committed
Allow specifying a specific constraint to use in ON CONFLICT
1 parent c451255 commit eb4b7ba

File tree

4 files changed

+87
-3
lines changed

4 files changed

+87
-3
lines changed

docs/source/conflict_handling.rst

+35
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,41 @@ Specifying multiple columns is necessary in case of a constraint that spans mult
8787
)
8888
8989
90+
Specific constraint
91+
*******************
92+
93+
Alternatively, instead of specifying the columns the constraint you're targetting applies to, you can also specify the exact constraint to use:
94+
95+
.. code-block:: python
96+
97+
from django.db import models
98+
from psqlextra.models import PostgresModel
99+
100+
class MyModel(PostgresModel)
101+
class Meta:
102+
constraints = [
103+
models.UniqueConstraint(
104+
name="myconstraint",
105+
fields=["first_name", "last_name"]
106+
),
107+
]
108+
109+
first_name = models.CharField(max_length=255)
110+
last_name = models.CharField(max_length=255)
111+
112+
constraint = next(
113+
constraint
114+
for constraint in MyModel._meta.constraints
115+
if constraint.name == "myconstraint"
116+
), None)
117+
118+
obj = (
119+
MyModel.objects
120+
.on_conflict(constraint, ConflictAction.UPDATE)
121+
.insert_and_get(first_name='Henk', last_name='Jansen')
122+
)
123+
124+
90125
HStore keys
91126
***********
92127
Catching conflicts in columns with a ``UNIQUE`` constraint on a :class:`~psqlextra.fields.HStoreField` key is also supported:

psqlextra/compiler.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,11 @@ def _rewrite_insert_on_conflict(
243243

244244
# build the conflict target, the columns to watch
245245
# for conflicts
246-
conflict_target = self._build_conflict_target()
246+
on_conflict_clause = self._build_on_conflict_clause()
247247
index_predicate = self.query.index_predicate # type: ignore[attr-defined]
248248
update_condition = self.query.conflict_update_condition # type: ignore[attr-defined]
249249

250-
rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"
250+
rewritten_sql = f"{sql} {on_conflict_clause}"
251251

252252
if index_predicate:
253253
expr_sql, expr_params = self._compile_expression(index_predicate)
@@ -270,6 +270,21 @@ def _rewrite_insert_on_conflict(
270270

271271
return (rewritten_sql, params)
272272

273+
def _build_on_conflict_clause(self):
274+
if django.VERSION >= (2, 2):
275+
from django.db.models.constraints import BaseConstraint
276+
from django.db.models.indexes import Index
277+
278+
if isinstance(
279+
self.query.conflict_target, BaseConstraint
280+
) or isinstance(self.query.conflict_target, Index):
281+
return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
282+
self.query.conflict_target.name
283+
)
284+
285+
conflict_target = self._build_conflict_target()
286+
return f"ON CONFLICT {conflict_target}"
287+
273288
def _build_conflict_target(self):
274289
"""Builds the `conflict_target` for the ON CONFLICT clause."""
275290

psqlextra/query.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from .sql import PostgresInsertQuery, PostgresQuery
2121
from .types import ConflictAction
2222

23-
ConflictTarget = List[Union[str, Tuple[str]]]
23+
if TYPE_CHECKING:
24+
from django.db.models.constraints import BaseConstraint
25+
from django.db.models.indexes import Index
26+
27+
ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"]
2428

2529

2630
TModel = TypeVar("TModel", bound=models.Model, covariant=True)

tests/test_on_conflict_update.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import django
12
import pytest
23

34
from django.db import models
@@ -41,6 +42,35 @@ def test_on_conflict_update():
4142
assert obj2.cookies == "choco"
4243

4344

45+
@pytest.mark.skipif(
46+
django.VERSION < (2, 2),
47+
reason="Django < 2.2 doesn't implement constraints",
48+
)
49+
def test_on_conflict_update_by_unique_constraint():
50+
model = get_fake_model(
51+
{
52+
"title": models.CharField(max_length=255, null=True),
53+
},
54+
meta_options={
55+
"constraints": [
56+
models.UniqueConstraint(name="test_uniq", fields=["title"]),
57+
],
58+
},
59+
)
60+
61+
constraint = next(
62+
(
63+
constraint
64+
for constraint in model._meta.constraints
65+
if constraint.name == "test_uniq"
66+
)
67+
)
68+
69+
model.objects.on_conflict(constraint, ConflictAction.UPDATE).insert_and_get(
70+
title="title"
71+
)
72+
73+
4474
def test_on_conflict_update_foreign_key_by_object():
4575
"""Tests whether simple upsert works correctly when the conflicting field
4676
is a foreign key specified as an object."""

0 commit comments

Comments
 (0)