|
2 | 2 | # -*- coding: utf-8 -*-
|
3 | 3 | from typing import Any, Generic, Iterable, Sequence, Type
|
4 | 4 |
|
5 |
| -from sqlalchemy import Row, RowMapping, Select, delete, select, update |
| 5 | +from sqlalchemy import Row, RowMapping, Select, delete, inspect, select, update |
6 | 6 | from sqlalchemy.ext.asyncio import AsyncSession
|
7 | 7 |
|
8 |
| -from sqlalchemy_crud_plus.errors import MultipleResultsError |
| 8 | +from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError |
9 | 9 | from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
|
10 | 10 | from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters
|
11 | 11 |
|
12 | 12 |
|
13 | 13 | class CRUDPlus(Generic[Model]):
|
14 | 14 | def __init__(self, model: Type[Model]):
|
15 | 15 | self.model = model
|
| 16 | + self.primary_key = self._get_primary_key() |
| 17 | + |
| 18 | + def _get_primary_key(self): |
| 19 | + """ |
| 20 | + Dynamically retrieve the primary key column(s) for the model. |
| 21 | + """ |
| 22 | + mapper = inspect(self.model) |
| 23 | + primary_key = mapper.primary_key |
| 24 | + if len(primary_key) == 1: |
| 25 | + return primary_key[0] |
| 26 | + else: |
| 27 | + raise CompositePrimaryKeysError('Composite primary keys are not supported') |
16 | 28 |
|
17 | 29 | async def create_model(
|
18 | 30 | self,
|
@@ -69,7 +81,7 @@ async def select_model(self, session: AsyncSession, pk: int) -> Model | None:
|
69 | 81 | :param pk: The database primary key value.
|
70 | 82 | :return:
|
71 | 83 | """
|
72 |
| - stmt = select(self.model).where(self.model.id == pk) |
| 84 | + stmt = select(self.model).where(self.primary_key == pk) |
73 | 85 | query = await session.execute(stmt)
|
74 | 86 | return query.scalars().first()
|
75 | 87 |
|
@@ -166,7 +178,7 @@ async def update_model(
|
166 | 178 | instance_data = obj
|
167 | 179 | else:
|
168 | 180 | instance_data = obj.model_dump(exclude_unset=True)
|
169 |
| - stmt = update(self.model).where(self.model.id == pk).values(**instance_data) |
| 181 | + stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) |
170 | 182 | result = await session.execute(stmt)
|
171 | 183 | if commit:
|
172 | 184 | await session.commit()
|
@@ -218,7 +230,7 @@ async def delete_model(
|
218 | 230 | :param commit: If `True`, commits the transaction immediately. Default is `False`.
|
219 | 231 | :return:
|
220 | 232 | """
|
221 |
| - stmt = delete(self.model).where(self.model.id == pk) |
| 233 | + stmt = delete(self.model).where(self.primary_key == pk) |
222 | 234 | result = await session.execute(stmt)
|
223 | 235 | if commit:
|
224 | 236 | await session.commit()
|
|
0 commit comments