diff --git a/ifsguid/crud.py b/ifsguid/crud.py index e965484..e8bc537 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -1,7 +1,8 @@ from typing import List from uuid import UUID -from sqlalchemy.orm import Session, joinedload, selectinload +from sqlalchemy import delete, update +from sqlalchemy.orm import joinedload from sqlalchemy.future import select from sqlalchemy.ext.asyncio import AsyncSession @@ -40,13 +41,11 @@ async def create_interaction( async def delete_interaction(db: AsyncSession, id: UUID) -> None: - interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() - ) + stmt = delete(models.Interaction).where(models.Interaction.id == id) + result = await db.execute(stmt) - if interaction is not None: - db.delete(interaction) - db.commit() + if result.rowcount: + await db.commit() return True return False @@ -55,14 +54,16 @@ async def delete_interaction(db: AsyncSession, id: UUID) -> None: async def update_interaction( db: AsyncSession, id: UUID, settings: schemas.Settings ) -> models.Interaction: - interaction: models.Interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() + stmt = ( + update(models.Interaction). + where(models.Interaction.id == id). + values(settings=settings) ) + result = await db.execute(stmt) - if interaction is not None: - interaction.settings = settings - db.commit() - return interaction + if result.rowcount: + await db.commit() + return True return None @@ -70,15 +71,16 @@ async def update_interaction( async def get_messages( db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10 ) -> List[models.Message]: - query = db.query(models.Message) + stmt = select(models.Message) if interaction_id is not None: - query = query.filter(models.Message.interaction_id == interaction_id) + stmt = stmt.where(models.Message.interaction_id == interaction_id) if page is not None: - query = query.offset((page - 1) * per_page).limit(per_page) + stmt = stmt.offset((page - 1) * per_page).limit(per_page) - return query.all() + result = await db.execute(stmt) + return result.scalars().all() async def create_message( diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index 2329884..b16f4b4 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -85,19 +85,18 @@ async def get_all_message_in_interaction( per_page: Optional[int] = None, db: AsyncSession = Depends(get_db), ) -> List[schemas.Message]: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Interaction not found" ) - return [ - schemas.Message.model_validate(message) - for message in crud.get_messages( - db=db, interaction_id=str(interaction_id), page=page, per_page=per_page - ) - ] + messages = await crud.get_messages( + db=db, interaction_id=str(interaction_id), page=page, per_page=per_page + ) + + return [schemas.Message.model_validate(message) for message in messages] @router.post(