diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..7687239 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,7 @@ +coverage: + status: + project: + default: + # basic + target: auto + threshold: null diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index fc1eea4..853d755 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -36,12 +36,12 @@ jobs: run: | poetry run pytest -s env: - SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db + SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db - name: Code Coverage run: | poetry run pytest --cov=./ifsguid --cov-report=xml --cov-report=term-missing env: - SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db + SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 env: diff --git a/README.md b/README.md index 81813bd..fc8fce2 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,5 @@ To launch an API instance, you should: You can also run the project via `docker-compose` (i.e. `docker compose up -d`) on port `80` in which you would need the [.docker.env](/.docker.env) containing the following variable to create the database: ``` -SQLALCHEMY_DATABASE_URI=postgresql://:@ifsguid_db/ +SQLALCHEMY_DATABASE_URI=postgresql+asyncpg://:@ifsguid_db/ ``` - - diff --git a/alembic/env.py b/alembic/env.py index ece0603..9022ef9 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -59,7 +59,9 @@ def run_migrations_online() -> None: """ configuration = config.get_section(config.config_ini_section) - configuration["sqlalchemy.url"] = f"{settings.SQLALCHEMY_DATABASE_URI}" + configuration[ + "sqlalchemy.url" + ] = settings.SQLALCHEMY_DATABASE_URI.unicode_string().replace("+asyncpg", "") connectable = engine_from_config( configuration, prefix="sqlalchemy.", diff --git a/ifsguid/config.py b/ifsguid/config.py index 8aefecc..bca64c9 100644 --- a/ifsguid/config.py +++ b/ifsguid/config.py @@ -21,7 +21,7 @@ def assemble_db_connection(cls, v: Optional[str], values: ValidationInfo) -> Any return v print("Creating SQLALCHEMY_DATABASE_URI from .env file ...") return PostgresDsn.build( - scheme="postgresql", + scheme="postgresql+asyncpg", username=values.data.get("POSTGRES_USER"), password=values.data.get("POSTGRES_PASSWORD"), host=values.data.get("POSTGRES_SERVER"), diff --git a/ifsguid/crud.py b/ifsguid/crud.py index e91b140..05225c8 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -1,79 +1,90 @@ from typing import List from uuid import UUID -from sqlalchemy.orm import Session +from sqlalchemy import delete, update +from sqlalchemy.orm import joinedload +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession -from . import models, schemas, utils +from . import models, schemas -def get_interactions( - db: Session, page: int = None, per_page: int = 10 +async def get_interactions( + db: AsyncSession, page: int = None, per_page: int = 10 ) -> List[models.Interaction]: - query = db.query(models.Interaction) + stmt = select(models.Interaction).options(joinedload(models.Interaction.messages)) if page is not None: - query = query.offset((page - 1) * per_page).limit(per_page) + stmt = stmt.offset((page - 1) * per_page).limit(per_page) - return query.all() + result = await db.execute(stmt) + return result.scalars().unique().all() -def get_interaction(db: Session, id: UUID) -> models.Interaction: - return db.query(models.Interaction).filter(models.Interaction.id == id).first() +async def get_interaction(db: AsyncSession, id: UUID) -> models.Interaction: + stmt = select(models.Interaction).where(models.Interaction.id == id) + result = await db.execute(stmt) + return result.scalar() -def create_interaction(db: Session, settings: schemas.Settings) -> models.Interaction: +async def create_interaction( + db: AsyncSession, settings: schemas.Settings +) -> models.Interaction: interaction = models.Interaction( settings=settings.model_dump(), ) db.add(interaction) - db.commit() + await db.commit() + await db.refresh(interaction) + return interaction -def delete_interaction(db: Session, id: UUID) -> None: - interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() - ) +async def delete_interaction(db: AsyncSession, id: UUID) -> None: + stmt = delete(models.Interaction).where(models.Interaction.id == id) + result = await db.execute(stmt) - if interaction is not None: - db.delete(interaction) - db.commit() + if result.rowcount: + await db.commit() return True return False -def update_interaction( - db: Session, id: UUID, settings: schemas.Settings +async def update_interaction( + db: AsyncSession, id: UUID, settings: schemas.Settings ) -> models.Interaction: - interaction: models.Interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() + stmt = ( + update(models.Interaction) + .where(models.Interaction.id == id) + .values(settings=settings) ) + result = await db.execute(stmt) - if interaction is not None: - interaction.settings = settings - db.commit() - return interaction + if result.rowcount: + await db.commit() + return True return None -def get_messages( - db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10 +async def get_messages( + db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10 ) -> List[models.Message]: - query = db.query(models.Message) + stmt = select(models.Message) if interaction_id is not None: - query = query.filter(models.Message.interaction_id == interaction_id) + stmt = stmt.where(models.Message.interaction_id == interaction_id) if page is not None: - query = query.offset((page - 1) * per_page).limit(per_page) + stmt = stmt.offset((page - 1) * per_page).limit(per_page) - return query.all() + result = await db.execute(stmt) + return result.scalars().all() -def create_message( - db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID +async def create_message( + db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID ) -> List[models.Message]: messages_db = [] for msg in messages: @@ -84,5 +95,5 @@ def create_message( db.add(message) messages_db.append(message) - db.commit() + await db.commit() return messages_db diff --git a/ifsguid/database.py b/ifsguid/database.py index b5467af..27da591 100644 --- a/ifsguid/database.py +++ b/ifsguid/database.py @@ -1,9 +1,16 @@ -from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine from sqlalchemy.orm import sessionmaker from .config import settings -engine = create_engine( - settings.SQLALCHEMY_DATABASE_URI.unicode_string(), pool_pre_ping=True +engine: AsyncEngine = create_async_engine( + settings.SQLALCHEMY_DATABASE_URI.unicode_string() +) +async_session = sessionmaker( + autocommit=False, + autoflush=False, + future=True, + bind=engine, + class_=AsyncSession, + expire_on_commit=False, ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True, bind=engine) diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index dd2019e..b16f4b4 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -2,24 +2,21 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session +from g4f import ModelUtils from . import crud, schemas, modules -from .database import SessionLocal +from .database import async_session, AsyncSession router = APIRouter() -def get_db() -> SessionLocal: - try: - db = SessionLocal() - yield db - finally: - db.close() +async def get_db() -> AsyncSession: + async with async_session() as session: + yield session @router.get("/", response_model=str) -async def get_root(db: Session = Depends(get_db)) -> str: +async def get_root(db: AsyncSession = Depends(get_db)) -> str: return "Hello from IFSGuid!" @@ -27,11 +24,12 @@ async def get_root(db: Session = Depends(get_db)) -> str: async def get_all_interactions( page: Optional[int] = None, per_page: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> List[schemas.Interaction]: + interactions = await crud.get_interactions(db=db, page=page, per_page=per_page) + return [ - schemas.Interaction.model_validate(interaction) - for interaction in crud.get_interactions(db=db, page=page, per_page=per_page) + schemas.Interaction.model_validate(interaction) for interaction in interactions ] @@ -39,7 +37,7 @@ async def get_all_interactions( "/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False ) async def get_interactions( - id: UUID, db: Session = Depends(get_db) + id: UUID, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" @@ -48,15 +46,20 @@ async def get_interactions( @router.post("/interactions", response_model=schemas.Interaction) async def create_interactions( - settings: schemas.Settings, db: Session = Depends(get_db) + prompt: schemas.Prompt, + chat_model: schemas.ChatModel = Depends(), + db: AsyncSession = Depends(get_db), ) -> schemas.Interaction: - return schemas.Interaction.model_validate( - crud.create_interaction(db=db, settings=settings) + settings = schemas.Settings( + model=chat_model.model, prompt=prompt.prompt, role=prompt.role ) + interaction = await crud.create_interaction(db=db, settings=settings) + + return schemas.Interaction.model_validate(interaction) @router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False) -async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: +async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" ) @@ -66,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: "/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False ) async def update_interaction( - id: UUID, settings: schemas.Settings, db: Session = Depends(get_db) + id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" @@ -80,30 +83,31 @@ async def get_all_message_in_interaction( interaction_id: UUID, page: Optional[int] = None, per_page: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> List[schemas.Message]: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Interaction not found" ) - return [ - schemas.Message.model_validate(message) - for message in crud.get_messages( - db=db, interaction_id=str(interaction_id), page=page, per_page=per_page - ) - ] + messages = await crud.get_messages( + db=db, interaction_id=str(interaction_id), page=page, per_page=per_page + ) + + return [schemas.Message.model_validate(message) for message in messages] @router.post( "/interactions/{interactions_id}/messages", response_model=List[schemas.Message] ) async def create_message( - interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db) + interaction_id: UUID, + message: schemas.MessageCreate, + db: AsyncSession = Depends(get_db), ) -> schemas.Message: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( @@ -114,17 +118,16 @@ async def create_message( messages = [] if message.role == "human": - ai_content = modules.generate_ai_response( - content=message.content, model=interaction.settings.model_name + ai_content = await modules.generate_ai_response( + content=message.content, + model=ModelUtils.convert[interaction.settings.model], ) ai_message = schemas.MessageCreate(role="ai", content=ai_content) messages.append(message) messages.append(ai_message) - return [ - schemas.Message.model_validate(message) - for message in crud.create_message( - db=db, messages=messages, interaction_id=str(interaction_id) - ) - ] + messages = await crud.create_message( + db=db, messages=messages, interaction_id=str(interaction_id) + ) + return [schemas.Message.model_validate(message) for message in messages] diff --git a/ifsguid/models.py b/ifsguid/models.py index 7c67ea0..62e1251 100644 --- a/ifsguid/models.py +++ b/ifsguid/models.py @@ -53,7 +53,7 @@ class Interaction(Base): ) settings = Column(JSON) - messages = relationship("Message", back_populates="interaction") + messages = relationship("Message", back_populates="interaction", lazy="selectin") class Message(Base): diff --git a/ifsguid/modules.py b/ifsguid/modules.py index b614674..861f73c 100644 --- a/ifsguid/modules.py +++ b/ifsguid/modules.py @@ -1,13 +1,19 @@ import g4f +import traceback g4f.debug.logging = True g4f.check_version = False -def generate_ai_response( - content: str, model: str = "GPT4" -) -> str: # TODO: use async version for future - return g4f.ChatCompletion.create( - model=g4f.models.gpt_4 if model == "GPT4" else g4f.models.gpt_35_turbo, - messages=[{"role": "human", "content": content}], - ) +async def generate_ai_response( + content: str, model: g4f.Model = g4f.models.default +) -> str: + try: + response = await g4f.ChatCompletion.create_async( + model=model, + messages=[{"role": "human", "content": content}], + ) + return response + except Exception: + traceback.print_exc() + return "error!" diff --git a/ifsguid/schemas.py b/ifsguid/schemas.py index de6b06d..c005748 100644 --- a/ifsguid/schemas.py +++ b/ifsguid/schemas.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, ConfigDict from typing import List, Literal +from g4f import _all_models + class MessageCreate(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -18,12 +20,19 @@ class Message(MessageCreate): created_at: datetime -class Settings(BaseModel): - model_name: Literal["GPT4", "GPT3"] = "GPT3" +class ChatModel(BaseModel): + model: Literal[tuple(_all_models)] = "gpt-3.5-turbo" + + +class Prompt(BaseModel): role: Literal["System"] = "System" prompt: str +class Settings(ChatModel, Prompt): + pass + + class InteractionCreate(BaseModel): model_config = ConfigDict(from_attributes=True) diff --git a/poetry.lock b/poetry.lock index 7d99f01..6585efc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -133,6 +133,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.19.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"}, + {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"}, +] + +[package.extras] +dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] +docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "alembic" version = "1.12.1" @@ -208,6 +223,63 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "asyncpg" +version = "0.29.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "asyncpg-0.29.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169"}, + {file = "asyncpg-0.29.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb"}, + {file = "asyncpg-0.29.0-cp310-cp310-win32.whl", hash = "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449"}, + {file = "asyncpg-0.29.0-cp310-cp310-win_amd64.whl", hash = "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b"}, + {file = "asyncpg-0.29.0-cp311-cp311-win32.whl", hash = "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675"}, + {file = "asyncpg-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb"}, + {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364"}, + {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175"}, + {file = "asyncpg-0.29.0-cp312-cp312-win32.whl", hash = "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02"}, + {file = "asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9"}, + {file = "asyncpg-0.29.0-cp38-cp38-win32.whl", hash = "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408"}, + {file = "asyncpg-0.29.0-cp38-cp38-win_amd64.whl", hash = "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c"}, + {file = "asyncpg-0.29.0-cp39-cp39-win32.whl", hash = "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2"}, + {file = "asyncpg-0.29.0-cp39-cp39-win_amd64.whl", hash = "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8"}, + {file = "asyncpg-0.29.0.tar.gz", hash = "sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.12.0\""} + +[package.extras] +docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"] + [[package]] name = "attrs" version = "23.1.0" @@ -2016,6 +2088,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -2632,4 +2722,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "c37752c01396d4b331396b8d25d7108ca0dee723cfd3b44842ff3b81775894f9" +content-hash = "60b8ce6c40d2b0d4d9c6d306b5ad17ca4cf02951f0a8941808b4603832946991" diff --git a/pyproject.toml b/pyproject.toml index b7d055c..a7c927e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,15 @@ psycopg2-binary = "^2.9.9" pydantic-settings = "^2.0.3" pytz = "^2023.3.post1" g4f = "^0.1.8.1" +asyncpg = "^0.29.0" [tool.poetry.group.dev.dependencies] requests = "^2.31.0" pytest = "7.4.2" black = "^23.10.1" pytest-cov = "^4.1.0" +pytest-asyncio = "^0.21.1" +aiosqlite = "^0.19.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/client.py b/tests/client.py index 0af0ec6..c8d3215 100644 --- a/tests/client.py +++ b/tests/client.py @@ -1,32 +1,31 @@ -from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from fastapi.testclient import TestClient -from sqlalchemy.pool import StaticPool +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine from ifsguid import models from ifsguid.main import app from ifsguid.endpoints import get_db -SQLALCHEMY_DATABASE_URL = "sqlite://" +DATABASE_URL = "sqlite+aiosqlite:///:memory:" -engine = create_engine( - SQLALCHEMY_DATABASE_URL, - connect_args={"check_same_thread": False}, - poolclass=StaticPool, -) -TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +engine: AsyncEngine = create_async_engine(DATABASE_URL) +async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) -# models.Base.metadata.create_all(bind=engine) +async def create_tables(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.create_all) -def override_get_db() -> TestingSessionLocal: - try: - models.Base.metadata.create_all(bind=engine) - session = TestingSessionLocal() + +async def drop_tables(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.drop_all) + + +async def override_get_db() -> AsyncSession: + async with async_session() as session: yield session - finally: - session.close() app.dependency_overrides[get_db] = override_get_db diff --git a/tests/conftest.py b/tests/conftest.py index 759adf4..badef90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,13 @@ -import pytest +import pytest_asyncio -from ifsguid import models -from .client import override_get_db, engine +from sqlalchemy.ext.asyncio import AsyncSession +from . import client -@pytest.fixture -def db(): - try: - yield from override_get_db() - finally: - models.Base.metadata.drop_all(bind=engine) + +@pytest_asyncio.fixture() +async def db() -> AsyncSession: + async with client.async_session() as session: + await client.create_tables() + yield session + await client.drop_tables() diff --git a/tests/test_crud.py b/tests/test_crud.py index 76f4507..47cf94b 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -1,33 +1,38 @@ +import pytest + from ifsguid import crud, models, schemas +from . import client ### Unit Tests ### -def test_get_interactions(db): +@pytest.mark.asyncio +async def test_get_interactions(db): interaction1 = models.Interaction( - settings=dict(model_name="model1", role="role1", prompt="prompt1"), + settings=dict(model="model1", role="role1", prompt="prompt1"), ) interaction2 = models.Interaction( - settings=dict(model_name="model2", role="role2", prompt="prompt2"), + settings=dict(model="model2", role="role2", prompt="prompt2"), ) db.add(interaction1) db.add(interaction2) - db.commit() + await db.commit() - interactions = crud.get_interactions(db) + interactions = await crud.get_interactions(db) assert len(interactions) == 2 - assert interactions[0].settings["model_name"] == "model1" - assert interactions[1].settings["model_name"] == "model2" + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" -def test_get_interaction(db): +@pytest.mark.asyncio +async def test_get_interaction(db): interaction = models.Interaction( - settings=dict(model_name="model", role="role", prompt="prompt"), + settings=dict(model="model", role="role", prompt="prompt"), ) db.add(interaction) - db.commit() + await db.commit() - retrieved_interaction = crud.get_interaction(db, interaction.id) + retrieved_interaction = await crud.get_interaction(db, interaction.id) assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model_name"] == "model" + assert retrieved_interaction.settings["model"] == "model" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 1c2ccd0..febc714 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,30 +1,34 @@ +import pytest + from ifsguid import models -from .client import client +from . import client ### Integration Tests ### def test_get_root(): - response = client.get("/api") + response = client.client.get("/api") assert response.status_code == 200 assert response.json() == "Hello from IFSGuid!" -def test_get_all_interactions(db): +@pytest.mark.asyncio +async def test_get_all_interactions(db): interaction1 = models.Interaction(settings={"prompt": "something"}) interaction2 = models.Interaction(settings={"prompt": "something else"}) db.add(interaction1) db.add(interaction2) - db.commit() + await db.commit() - response = client.get("/api/interactions") + response = client.client.get("/api/interactions") assert response.status_code == 200 assert len(response.json()) == 2 -def test_create_interaction(): - response = client.post( +@pytest.mark.asyncio +async def test_create_interaction(db): + response = client.client.post( "/api/interactions", json={ "prompt": "something", @@ -33,6 +37,6 @@ def test_create_interaction(): assert response.status_code == 200 assert response.json()["settings"] == { "prompt": "something", - "model_name": "GPT3", + "model": "gpt-3.5-turbo", "role": "System", }