12
12
from saffier .types import DictAny
13
13
14
14
if typing .TYPE_CHECKING : # pragma: no cover
15
+ from saffier .db .connection import Database
15
16
from saffier .models import Model
16
17
17
18
@@ -25,7 +26,7 @@ class QuerySetProps:
25
26
"""
26
27
27
28
@property
28
- def database (self ):
29
+ def database (self ) -> "Database" :
29
30
return self .model_class ._meta .registry .database
30
31
31
32
@property
@@ -117,6 +118,7 @@ def _build_select(self):
117
118
if self .distinct_on :
118
119
expression = self ._build_select_distinct (self .distinct_on , expression = expression )
119
120
121
+ setattr (self , "_expression" , expression )
120
122
return expression
121
123
122
124
def _filter_query (self , exclude : bool = False , ** kwargs ):
@@ -232,6 +234,7 @@ def _clone(self) -> "QuerySet[SaffierModel]":
232
234
queryset ._order_by = copy .copy (self ._order_by )
233
235
queryset ._group_by = copy .copy (self ._group_by )
234
236
queryset .distinct_on = copy .copy (self .distinct_on )
237
+ queryset ._expression = self ._expression
235
238
return queryset
236
239
237
240
@@ -262,14 +265,30 @@ def __init__(
262
265
self ._order_by = [] if order_by is None else order_by
263
266
self ._group_by = [] if group_by is None else group_by
264
267
self .distinct_on = [] if distinct_on is None else distinct_on
268
+ self ._expression = None
265
269
266
270
def __get__ (self , instance , owner ):
267
271
return self .__class__ (model_class = owner )
268
272
273
+ @property
274
+ def sql (self ):
275
+ return str (self ._expression )
276
+
277
+ @sql .setter
278
+ def sql (self , value ):
279
+ setattr (self , "_expression" , value )
280
+
269
281
async def __aiter__ (self ) -> typing .AsyncIterator [SaffierModel ]:
270
282
for value in await self :
271
283
yield value
272
284
285
+ def _set_query_expression (self , expression : typing .Any ) -> None :
286
+ """
287
+ Sets the value of the sql property to the expression used.
288
+ """
289
+ self .sql = expression
290
+ self .model_class .raw_query = self .sql
291
+
273
292
def _filter_or_exclude (
274
293
self ,
275
294
clause : typing .Optional [sqlalchemy .sql .expression .BinaryExpression ] = None ,
@@ -389,6 +408,7 @@ async def exists(self) -> bool:
389
408
"""
390
409
expression = self ._build_select ()
391
410
expression = sqlalchemy .exists (expression ).select ()
411
+ self ._set_query_expression (expression )
392
412
return await self .database .fetch_val (expression )
393
413
394
414
async def count (self ) -> int :
@@ -397,6 +417,7 @@ async def count(self) -> int:
397
417
"""
398
418
expression = self ._build_select ().alias ("subquery_for_count" )
399
419
expression = sqlalchemy .func .count ().select ().select_from (expression )
420
+ self ._set_query_expression (expression )
400
421
return await self .database .fetch_val (expression )
401
422
402
423
async def get_or_none (self , ** kwargs ):
@@ -405,6 +426,7 @@ async def get_or_none(self, **kwargs):
405
426
"""
406
427
queryset = self .filter (** kwargs )
407
428
expression = queryset ._build_select ().limit (2 )
429
+ self ._set_query_expression (expression )
408
430
rows = await self .database .fetch_all (expression )
409
431
410
432
if not rows :
@@ -422,7 +444,12 @@ async def all(self, **kwargs):
422
444
return await queryset .filter (** kwargs ).all ()
423
445
424
446
expression = queryset ._build_select ()
447
+ self ._set_query_expression (expression )
448
+
425
449
rows = await queryset .database .fetch_all (expression )
450
+
451
+ # Attach the raw query to the object
452
+ queryset .model_class .raw_query = self .sql
426
453
return [
427
454
queryset .model_class ._from_row (row , select_related = self ._select_related )
428
455
for row in rows
@@ -437,6 +464,7 @@ async def get(self, **kwargs):
437
464
438
465
expression = self ._build_select ().limit (2 )
439
466
rows = await self .database .fetch_all (expression )
467
+ self ._set_query_expression (expression )
440
468
441
469
if not rows :
442
470
raise DoesNotFound ()
@@ -475,6 +503,7 @@ async def create(self, **kwargs):
475
503
kwargs = self ._validate_kwargs (** kwargs )
476
504
instance = self .model_class (** kwargs )
477
505
expression = self .table .insert ().values (** kwargs )
506
+ self ._set_query_expression (expression )
478
507
479
508
if self .pkname not in kwargs :
480
509
instance .pk = await self .database .execute (expression )
@@ -490,13 +519,57 @@ async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
490
519
new_objs = [self ._validate_kwargs (** obj ) for obj in objs ]
491
520
492
521
expression = self .table .insert ().values (new_objs )
522
+ self ._set_query_expression (expression )
493
523
await self .database .execute (expression )
494
524
525
+ async def bulk_update (self , objs : typing .List [SaffierModel ], fields : typing .List [str ]) -> None :
526
+ """
527
+ Bulk updates records in a table.
528
+
529
+ A similar solution was suggested here: https://github.com/encode/orm/pull/148
530
+
531
+ It is thought to be a clean approach to a simple problem so it was added here and
532
+ refactored to be compatible with Saffier.
533
+ """
534
+ new_fields = {}
535
+ for key , field in self .model_class .fields .items ():
536
+ if key in fields :
537
+ new_fields [key ] = field .validator
538
+
539
+ validator = Schema (fields = new_fields )
540
+
541
+ new_objs = []
542
+ for obj in objs :
543
+ new_obj = {}
544
+ for key , value in obj .__dict__ .items ():
545
+ if key in fields :
546
+ new_obj [key ] = self ._resolve_value (value )
547
+ new_objs .append (new_obj )
548
+
549
+ new_objs = [
550
+ self ._update_auto_now_fields (validator .validate (obj ), self .model_class .fields )
551
+ for obj in new_objs
552
+ ]
553
+
554
+ pk = getattr (self .table .c , self .pkname )
555
+ expression = self .table .update ().where (pk == sqlalchemy .bindparam (self .pkname ))
556
+ kwargs = {field : sqlalchemy .bindparam (field ) for obj in new_objs for field in obj .keys ()}
557
+ pks = [{self .pkname : getattr (obj , self .pkname )} for obj in objs ]
558
+
559
+ query_list = []
560
+ for pk , value in zip (pks , new_objs ):
561
+ query_list .append ({** pk , ** value })
562
+
563
+ expression = expression .values (kwargs )
564
+ self ._set_query_expression (expression )
565
+ await self .database .execute_many (str (expression ), query_list )
566
+
495
567
async def delete (self ) -> None :
496
568
expression = self .table .delete ()
497
569
for filter_clause in self .filter_clauses :
498
570
expression = expression .where (filter_clause )
499
571
572
+ self ._set_query_expression (expression )
500
573
await self .database .execute (expression )
501
574
502
575
async def update (self , ** kwargs ) -> None :
@@ -509,12 +582,13 @@ async def update(self, **kwargs) -> None:
509
582
510
583
validator = Schema (fields = fields )
511
584
kwargs = self ._update_auto_now_fields (validator .validate (kwargs ), self .model_class .fields )
512
- expr = self .table .update ().values (** kwargs )
585
+ expression = self .table .update ().values (** kwargs )
513
586
514
587
for filter_clause in self .filter_clauses :
515
- expr = expr .where (filter_clause )
588
+ expression = expression .where (filter_clause )
516
589
517
- await self .database .execute (expr )
590
+ self ._set_query_expression (expression )
591
+ await self .database .execute (expression )
518
592
519
593
async def get_or_create (
520
594
self , defaults : typing .Dict [str , typing .Any ], ** kwargs
0 commit comments