Skip to content

Commit d0b4df0

Browse files
committed
Make ExcludedCol work with fields and use them when constructing SET clause
The recent changes to add support for custom update values had a side effect that upserts with PostGIS related fields would break. They would break while building the `SET` clause. Django would try to figure out the right placeholder for the expression, even though none is required. Since there was not associated field information, it couldn't figure it out. By passing the field information, we ensure we always build the SET clause correctly.
1 parent 92ae690 commit d0b4df0

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

psqlextra/expressions.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from django.db.models import CharField, expressions
1+
from typing import Union
2+
3+
from django.db.models import CharField, Field, expressions
24

35

46
class HStoreValue(expressions.Expression):
@@ -215,8 +217,18 @@ class ExcludedCol(expressions.Expression):
215217
See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
216218
"""
217219

218-
def __init__(self, name: str):
219-
self.name = name
220+
def __init__(self, field_or_name: Union[Field, str]):
221+
222+
# We support both field classes or just field names here. We prefer
223+
# fields because when the expression is compiled, it might need
224+
# the field information to figure out the correct placeholder.
225+
# Even though that isn't require for this particular expression.
226+
if isinstance(field_or_name, Field):
227+
super().__init__(field_or_name)
228+
self.name = field_or_name.column
229+
else:
230+
super().__init__(None)
231+
self.name = field_or_name
220232

221233
def as_sql(self, compiler, connection):
222234
quoted_name = connection.ops.quote_name(self.name)

psqlextra/query.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def _get_upsert_fields(self, kwargs):
591591
has_default = field.default != NOT_PROVIDED
592592
if field.name in kwargs or field.column in kwargs:
593593
insert_fields.append(field)
594-
update_values[field.name] = ExcludedCol(field.column)
594+
update_values[field.name] = ExcludedCol(field)
595595
continue
596596
elif has_default:
597597
insert_fields.append(field)
@@ -602,13 +602,13 @@ def _get_upsert_fields(self, kwargs):
602602
# instead of a concrete field, we have to handle that
603603
if field.primary_key is True and "pk" in kwargs:
604604
insert_fields.append(field)
605-
update_values[field.name] = ExcludedCol(field.column)
605+
update_values[field.name] = ExcludedCol(field)
606606
continue
607607

608608
if self._is_magical_field(model_instance, field, is_insert=True):
609609
insert_fields.append(field)
610610

611611
if self._is_magical_field(model_instance, field, is_insert=False):
612-
update_values[field.name] = ExcludedCol(field.column)
612+
update_values[field.name] = ExcludedCol(field)
613613

614614
return insert_fields, update_values

0 commit comments

Comments
 (0)