Skip to content

Commit 494ee68

Browse files
committed
➕ bulk_update, tests and raw_sql
* The bulk_update is an adaptation of encode/orm#148
1 parent 640fb03 commit 494ee68

File tree

7 files changed

+308
-22
lines changed

7 files changed

+308
-22
lines changed

saffier/core/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import typing
2+
from enum import Enum
23
from inspect import isclass
34

5+
from orjson import OPT_OMIT_MICROSECONDS # noqa
6+
from orjson import OPT_SERIALIZE_NUMPY # noqa
7+
from orjson import dumps
48
from typing_extensions import get_origin
59

610
from saffier.fields import DateField, DateTimeField
@@ -21,6 +25,16 @@ def _update_auto_now_fields(self, values: DictAny, fields: DictAny) -> DictAny:
2125
values[k] = v.validator.get_default_value()
2226
return values
2327

28+
def _resolve_value(self, value: typing.Any):
29+
if isinstance(value, dict):
30+
return dumps(
31+
value,
32+
option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS,
33+
).decode("utf-8")
34+
elif isinstance(value, Enum):
35+
return value.name
36+
return value
37+
2438

2539
def is_class_and_subclass(value: typing.Any, _type: typing.Any) -> bool:
2640
original = get_origin(value)

saffier/db/fields.py

-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def validate(self, value: typing.Any) -> typing.Any:
145145
elif not isinstance(value, str):
146146
raise self.validation_error("type")
147147

148-
# The null character is always invalid.
149148
value = value.replace("\0", "")
150149

151150
if self.trim_whitespace:

saffier/db/queryset.py

+78-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from saffier.types import DictAny
1313

1414
if typing.TYPE_CHECKING: # pragma: no cover
15+
from saffier.db.connection import Database
1516
from saffier.models import Model
1617

1718

@@ -25,7 +26,7 @@ class QuerySetProps:
2526
"""
2627

2728
@property
28-
def database(self):
29+
def database(self) -> "Database":
2930
return self.model_class._meta.registry.database
3031

3132
@property
@@ -117,6 +118,7 @@ def _build_select(self):
117118
if self.distinct_on:
118119
expression = self._build_select_distinct(self.distinct_on, expression=expression)
119120

121+
setattr(self, "_expression", expression)
120122
return expression
121123

122124
def _filter_query(self, exclude: bool = False, **kwargs):
@@ -232,6 +234,7 @@ def _clone(self) -> "QuerySet[SaffierModel]":
232234
queryset._order_by = copy.copy(self._order_by)
233235
queryset._group_by = copy.copy(self._group_by)
234236
queryset.distinct_on = copy.copy(self.distinct_on)
237+
queryset._expression = self._expression
235238
return queryset
236239

237240

@@ -262,14 +265,30 @@ def __init__(
262265
self._order_by = [] if order_by is None else order_by
263266
self._group_by = [] if group_by is None else group_by
264267
self.distinct_on = [] if distinct_on is None else distinct_on
268+
self._expression = None
265269

266270
def __get__(self, instance, owner):
267271
return self.__class__(model_class=owner)
268272

273+
@property
274+
def sql(self):
275+
return str(self._expression)
276+
277+
@sql.setter
278+
def sql(self, value):
279+
setattr(self, "_expression", value)
280+
269281
async def __aiter__(self) -> typing.AsyncIterator[SaffierModel]:
270282
for value in await self:
271283
yield value
272284

285+
def _set_query_expression(self, expression: typing.Any) -> None:
286+
"""
287+
Sets the value of the sql property to the expression used.
288+
"""
289+
self.sql = expression
290+
self.model_class.raw_query = self.sql
291+
273292
def _filter_or_exclude(
274293
self,
275294
clause: typing.Optional[sqlalchemy.sql.expression.BinaryExpression] = None,
@@ -389,6 +408,7 @@ async def exists(self) -> bool:
389408
"""
390409
expression = self._build_select()
391410
expression = sqlalchemy.exists(expression).select()
411+
self._set_query_expression(expression)
392412
return await self.database.fetch_val(expression)
393413

394414
async def count(self) -> int:
@@ -397,6 +417,7 @@ async def count(self) -> int:
397417
"""
398418
expression = self._build_select().alias("subquery_for_count")
399419
expression = sqlalchemy.func.count().select().select_from(expression)
420+
self._set_query_expression(expression)
400421
return await self.database.fetch_val(expression)
401422

402423
async def get_or_none(self, **kwargs):
@@ -405,6 +426,7 @@ async def get_or_none(self, **kwargs):
405426
"""
406427
queryset = self.filter(**kwargs)
407428
expression = queryset._build_select().limit(2)
429+
self._set_query_expression(expression)
408430
rows = await self.database.fetch_all(expression)
409431

410432
if not rows:
@@ -422,7 +444,12 @@ async def all(self, **kwargs):
422444
return await queryset.filter(**kwargs).all()
423445

424446
expression = queryset._build_select()
447+
self._set_query_expression(expression)
448+
425449
rows = await queryset.database.fetch_all(expression)
450+
451+
# Attach the raw query to the object
452+
queryset.model_class.raw_query = self.sql
426453
return [
427454
queryset.model_class._from_row(row, select_related=self._select_related)
428455
for row in rows
@@ -437,6 +464,7 @@ async def get(self, **kwargs):
437464

438465
expression = self._build_select().limit(2)
439466
rows = await self.database.fetch_all(expression)
467+
self._set_query_expression(expression)
440468

441469
if not rows:
442470
raise DoesNotFound()
@@ -475,6 +503,7 @@ async def create(self, **kwargs):
475503
kwargs = self._validate_kwargs(**kwargs)
476504
instance = self.model_class(**kwargs)
477505
expression = self.table.insert().values(**kwargs)
506+
self._set_query_expression(expression)
478507

479508
if self.pkname not in kwargs:
480509
instance.pk = await self.database.execute(expression)
@@ -490,13 +519,57 @@ async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
490519
new_objs = [self._validate_kwargs(**obj) for obj in objs]
491520

492521
expression = self.table.insert().values(new_objs)
522+
self._set_query_expression(expression)
493523
await self.database.execute(expression)
494524

525+
async def bulk_update(self, objs: typing.List[SaffierModel], fields: typing.List[str]) -> None:
526+
"""
527+
Bulk updates records in a table.
528+
529+
A similar solution was suggested here: https://github.com/encode/orm/pull/148
530+
531+
It is thought to be a clean approach to a simple problem so it was added here and
532+
refactored to be compatible with Saffier.
533+
"""
534+
new_fields = {}
535+
for key, field in self.model_class.fields.items():
536+
if key in fields:
537+
new_fields[key] = field.validator
538+
539+
validator = Schema(fields=new_fields)
540+
541+
new_objs = []
542+
for obj in objs:
543+
new_obj = {}
544+
for key, value in obj.__dict__.items():
545+
if key in fields:
546+
new_obj[key] = self._resolve_value(value)
547+
new_objs.append(new_obj)
548+
549+
new_objs = [
550+
self._update_auto_now_fields(validator.validate(obj), self.model_class.fields)
551+
for obj in new_objs
552+
]
553+
554+
pk = getattr(self.table.c, self.pkname)
555+
expression = self.table.update().where(pk == sqlalchemy.bindparam(self.pkname))
556+
kwargs = {field: sqlalchemy.bindparam(field) for obj in new_objs for field in obj.keys()}
557+
pks = [{self.pkname: getattr(obj, self.pkname)} for obj in objs]
558+
559+
query_list = []
560+
for pk, value in zip(pks, new_objs):
561+
query_list.append({**pk, **value})
562+
563+
expression = expression.values(kwargs)
564+
self._set_query_expression(expression)
565+
await self.database.execute_many(str(expression), query_list)
566+
495567
async def delete(self) -> None:
496568
expression = self.table.delete()
497569
for filter_clause in self.filter_clauses:
498570
expression = expression.where(filter_clause)
499571

572+
self._set_query_expression(expression)
500573
await self.database.execute(expression)
501574

502575
async def update(self, **kwargs) -> None:
@@ -509,12 +582,13 @@ async def update(self, **kwargs) -> None:
509582

510583
validator = Schema(fields=fields)
511584
kwargs = self._update_auto_now_fields(validator.validate(kwargs), self.model_class.fields)
512-
expr = self.table.update().values(**kwargs)
585+
expression = self.table.update().values(**kwargs)
513586

514587
for filter_clause in self.filter_clauses:
515-
expr = expr.where(filter_clause)
588+
expression = expression.where(filter_clause)
516589

517-
await self.database.execute(expr)
590+
self._set_query_expression(expression)
591+
await self.database.execute(expression)
518592

519593
async def get_or_create(
520594
self, defaults: typing.Dict[str, typing.Any], **kwargs

saffier/models.py

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Model(ModelMeta, ModelUtil):
1818
query = Manager()
1919
_meta = MetaInfo(None)
2020
_db_model: bool = False
21+
_raw_query: str = None
2122

2223
def __init__(self, **kwargs: DictAny) -> None:
2324
if "pk" in kwargs:
@@ -55,6 +56,14 @@ def pk(self):
5556
def pk(self, value):
5657
setattr(self, self.pkname, value)
5758

59+
@property
60+
def raw_query(self):
61+
return getattr(self, self._raw_query)
62+
63+
@raw_query.setter
64+
def raw_query(self, value):
65+
setattr(self, self.raw_query, value)
66+
5867
def __repr__(self):
5968
return f"<{self.__class__.__name__}: {self}>"
6069

tests/models/test_bulk_create.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import datetime
2+
from enum import Enum
3+
4+
import pytest
5+
from tests.settings import DATABASE_URL
6+
7+
import saffier
8+
from saffier import fields
9+
from saffier.db.connection import Database
10+
11+
pytestmark = pytest.mark.anyio
12+
13+
database = Database(DATABASE_URL)
14+
models = saffier.Registry(database=database)
15+
16+
17+
def time():
18+
return datetime.datetime.now().time()
19+
20+
21+
class StatusEnum(Enum):
22+
DRAFT = "Draft"
23+
RELEASED = "Released"
24+
25+
26+
class Product(saffier.Model):
27+
id = fields.IntegerField(primary_key=True)
28+
uuid = fields.UUIDField(null=True)
29+
created = fields.DateTimeField(default=datetime.datetime.now)
30+
created_day = fields.DateField(default=datetime.date.today)
31+
created_time = fields.TimeField(default=time)
32+
created_date = fields.DateField(auto_now_add=True)
33+
created_datetime = fields.DateTimeField(auto_now_add=True)
34+
updated_datetime = fields.DateTimeField(auto_now=True)
35+
updated_date = fields.DateField(auto_now=True)
36+
data = fields.JSONField(default={})
37+
description = fields.CharField(blank=True, max_length=255)
38+
huge_number = fields.BigIntegerField(default=0)
39+
price = fields.DecimalField(max_digits=5, decimal_places=2, null=True)
40+
status = fields.ChoiceField(StatusEnum, default=StatusEnum.DRAFT)
41+
value = fields.FloatField(null=True)
42+
43+
class Meta:
44+
registry = models
45+
46+
47+
@pytest.fixture(autouse=True, scope="module")
48+
async def create_test_database():
49+
await models.create_all()
50+
yield
51+
await models.drop_all()
52+
53+
54+
@pytest.fixture(autouse=True)
55+
async def rollback_transactions():
56+
with database.force_rollback():
57+
async with database:
58+
yield
59+
60+
61+
async def test_bulk_create():
62+
await Product.query.bulk_create(
63+
[
64+
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
65+
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
66+
]
67+
)
68+
products = await Product.query.all()
69+
assert len(products) == 2
70+
assert products[0].data == {"foo": 123}
71+
assert products[0].value == 123.456
72+
assert products[0].status == StatusEnum.RELEASED
73+
assert products[1].data == {"foo": 456}
74+
assert products[1].value == 456.789
75+
assert products[1].status == StatusEnum.DRAFT

0 commit comments

Comments
 (0)