From bff2aaae23cd0a584fdd8eb9db3a448e9288fd88 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 00:07:00 +0330 Subject: [PATCH 01/23] =?UTF-8?q?=E2=9E=95=20UPDATE=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- poetry.lock | 59 +++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 7d99f01..ce21d6e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -208,6 +208,63 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "asyncpg" +version = "0.29.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "asyncpg-0.29.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169"}, + {file = "asyncpg-0.29.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb"}, + {file = "asyncpg-0.29.0-cp310-cp310-win32.whl", hash = "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449"}, + {file = "asyncpg-0.29.0-cp310-cp310-win_amd64.whl", hash = "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b"}, + {file = "asyncpg-0.29.0-cp311-cp311-win32.whl", hash = "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675"}, + {file = "asyncpg-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb"}, + {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364"}, + {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175"}, + {file = "asyncpg-0.29.0-cp312-cp312-win32.whl", hash = "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02"}, + {file = "asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9"}, + {file = "asyncpg-0.29.0-cp38-cp38-win32.whl", hash = "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408"}, + {file = "asyncpg-0.29.0-cp38-cp38-win_amd64.whl", hash = "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c"}, + {file = "asyncpg-0.29.0-cp39-cp39-win32.whl", hash = "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2"}, + {file = "asyncpg-0.29.0-cp39-cp39-win_amd64.whl", hash = "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8"}, + {file = "asyncpg-0.29.0.tar.gz", hash = "sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.12.0\""} + +[package.extras] +docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"] + [[package]] name = "attrs" version = "23.1.0" @@ -2632,4 +2689,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "c37752c01396d4b331396b8d25d7108ca0dee723cfd3b44842ff3b81775894f9" +content-hash = "2cc7cafc6637a43ce418bb8971e5b09134f4750d03e4c830b15f26625d1ea28d" diff --git a/pyproject.toml b/pyproject.toml index b7d055c..d4f5b0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ psycopg2-binary = "^2.9.9" pydantic-settings = "^2.0.3" pytz = "^2023.3.post1" g4f = "^0.1.8.1" +asyncpg = "^0.29.0" [tool.poetry.group.dev.dependencies] requests = "^2.31.0" From f046c3031e98670f76fabdd804a55591f9b7ee79 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:02:19 +0330 Subject: [PATCH 02/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20=F0=9F=97=83?= =?UTF-8?q?=EF=B8=8F=20feat(database):=20switch=20to=20asyncpg=20as=20Post?= =?UTF-8?q?greSQL=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- alembic/env.py | 4 +++- ifsguid/config.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index ece0603..9022ef9 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -59,7 +59,9 @@ def run_migrations_online() -> None: """ configuration = config.get_section(config.config_ini_section) - configuration["sqlalchemy.url"] = f"{settings.SQLALCHEMY_DATABASE_URI}" + configuration[ + "sqlalchemy.url" + ] = settings.SQLALCHEMY_DATABASE_URI.unicode_string().replace("+asyncpg", "") connectable = engine_from_config( configuration, prefix="sqlalchemy.", diff --git a/ifsguid/config.py b/ifsguid/config.py index 8aefecc..bca64c9 100644 --- a/ifsguid/config.py +++ b/ifsguid/config.py @@ -21,7 +21,7 @@ def assemble_db_connection(cls, v: Optional[str], values: ValidationInfo) -> Any return v print("Creating SQLALCHEMY_DATABASE_URI from .env file ...") return PostgresDsn.build( - scheme="postgresql", + scheme="postgresql+asyncpg", username=values.data.get("POSTGRES_USER"), password=values.data.get("POSTGRES_PASSWORD"), host=values.data.get("POSTGRES_SERVER"), From 5c17846b843475d4b0e36cbac9ab764fdfc6bacf Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:03:01 +0330 Subject: [PATCH 03/23] =?UTF-8?q?=F0=9F=97=83=EF=B8=8F=20feat(database):?= =?UTF-8?q?=20migrate=20to=20async=20SQLAlchemy=20engine?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/database.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ifsguid/database.py b/ifsguid/database.py index b5467af..940091e 100644 --- a/ifsguid/database.py +++ b/ifsguid/database.py @@ -1,9 +1,14 @@ -from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from .config import settings -engine = create_engine( - settings.SQLALCHEMY_DATABASE_URI.unicode_string(), pool_pre_ping=True +engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URI.unicode_string()) +async_session = sessionmaker( + autocommit=False, + autoflush=False, + future=True, + bind=engine, + class_=AsyncSession, + expire_on_commit=False, ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True, bind=engine) From 23190c96709caed1f2748ffbf841694d99be3745 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:06:24 +0330 Subject: [PATCH 04/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20feat(database):=20tr?= =?UTF-8?q?ansition=20to=20asynchronous=20SQLAlchemy=20operations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 23 ++++++++++++++--------- ifsguid/endpoints.py | 23 +++++++++++------------ 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index e91b140..54ed96e 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -1,24 +1,29 @@ from typing import List from uuid import UUID -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, joinedload, selectinload +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession -from . import models, schemas, utils +from . import models, schemas -def get_interactions( - db: Session, page: int = None, per_page: int = 10 +async def get_interactions( + db: AsyncSession, page: int = None, per_page: int = 10 ) -> List[models.Interaction]: - query = db.query(models.Interaction) + stmt = select(models.Interaction) if page is not None: - query = query.offset((page - 1) * per_page).limit(per_page) + stmt = stmt.offset((page - 1) * per_page).limit(per_page) - return query.all() + result = await db.execute(stmt) + return result.scalars().unique().all() -def get_interaction(db: Session, id: UUID) -> models.Interaction: - return db.query(models.Interaction).filter(models.Interaction.id == id).first() +async def get_interaction(db: AsyncSession, id: UUID) -> models.Interaction: + stmt = select(models.Interaction).where(models.Interaction.id == id) + result = await db.execute(stmt) + return result.scalar() def create_interaction(db: Session, settings: schemas.Settings) -> models.Interaction: diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index dd2019e..967d6cd 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -5,21 +5,19 @@ from sqlalchemy.orm import Session from . import crud, schemas, modules -from .database import SessionLocal +from .database import async_session, AsyncSession router = APIRouter() -def get_db() -> SessionLocal: - try: - db = SessionLocal() - yield db - finally: - db.close() +async def get_db() -> AsyncSession: + async with async_session() as session: + yield session + @router.get("/", response_model=str) -async def get_root(db: Session = Depends(get_db)) -> str: +async def get_root(db: AsyncSession = Depends(get_db)) -> str: return "Hello from IFSGuid!" @@ -27,11 +25,12 @@ async def get_root(db: Session = Depends(get_db)) -> str: async def get_all_interactions( page: Optional[int] = None, per_page: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> List[schemas.Interaction]: + interactions = await crud.get_interactions(db=db, page=page, per_page=per_page) + return [ - schemas.Interaction.model_validate(interaction) - for interaction in crud.get_interactions(db=db, page=page, per_page=per_page) + schemas.Interaction.model_validate(interaction) for interaction in interactions ] @@ -39,7 +38,7 @@ async def get_all_interactions( "/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False ) async def get_interactions( - id: UUID, db: Session = Depends(get_db) + id: UUID, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" From 13c0e733eb0e6b724599eb2852f8694c2947b2d5 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:07:09 +0330 Subject: [PATCH 05/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20optimized=20the=20qu?= =?UTF-8?q?ery=20using=20SQL=20Join?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index 54ed96e..68a9f79 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -11,7 +11,7 @@ async def get_interactions( db: AsyncSession, page: int = None, per_page: int = 10 ) -> List[models.Interaction]: - stmt = select(models.Interaction) + stmt = select(models.Interaction).options(joinedload(models.Interaction.messages)) if page is not None: stmt = stmt.offset((page - 1) * per_page).limit(per_page) From 3ac0a2a7ed4506251c7824d4586e292f6e86379b Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:08:41 +0330 Subject: [PATCH 06/23] =?UTF-8?q?=F0=9F=97=83=EF=B8=8F=20feat(models):=20o?= =?UTF-8?q?ptimize=20interaction-message=20relationship=20loading?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ifsguid/models.py b/ifsguid/models.py index 7c67ea0..62e1251 100644 --- a/ifsguid/models.py +++ b/ifsguid/models.py @@ -53,7 +53,7 @@ class Interaction(Base): ) settings = Column(JSON) - messages = relationship("Message", back_populates="interaction") + messages = relationship("Message", back_populates="interaction", lazy="selectin") class Message(Base): From 8afc5861cd49e9c7e2cda6f456b083ccb89e6b3d Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:11:05 +0330 Subject: [PATCH 07/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20feat(crud):=20conver?= =?UTF-8?q?t=20create=5Finteraction=20to=20async=20and=20add=20interaction?= =?UTF-8?q?=20retrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 13 +++++++++++-- ifsguid/endpoints.py | 8 ++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index 68a9f79..725b7cf 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -26,12 +26,21 @@ async def get_interaction(db: AsyncSession, id: UUID) -> models.Interaction: return result.scalar() -def create_interaction(db: Session, settings: schemas.Settings) -> models.Interaction: +async def create_interaction( + db: AsyncSession, settings: schemas.Settings +) -> models.Interaction: interaction = models.Interaction( settings=settings.model_dump(), ) db.add(interaction) - db.commit() + await db.commit() + result = await db.scalars( + select(models.Interaction) + .options(selectinload(models.Interaction.messages)) + .where(models.Interaction.id == interaction.id) + ) + interaction = result.first() + return interaction diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index 967d6cd..50608a7 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -47,11 +47,11 @@ async def get_interactions( @router.post("/interactions", response_model=schemas.Interaction) async def create_interactions( - settings: schemas.Settings, db: Session = Depends(get_db) + settings: schemas.Settings, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: - return schemas.Interaction.model_validate( - crud.create_interaction(db=db, settings=settings) - ) + interaction = await crud.create_interaction(db=db, settings=settings) + + return schemas.Interaction.model_validate(interaction) @router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False) From 0e095f5429ef1d083ac95ed1cba47e7a47feee0a Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:12:10 +0330 Subject: [PATCH 08/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20changed=20the=20quer?= =?UTF-8?q?y=20from=20selectin=20to=20joinedload?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index 725b7cf..757f50f 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -34,12 +34,12 @@ async def create_interaction( ) db.add(interaction) await db.commit() - result = await db.scalars( + result = await db.execute( select(models.Interaction) - .options(selectinload(models.Interaction.messages)) + .options(joinedload(models.Interaction.messages)) .where(models.Interaction.id == interaction.id) ) - interaction = result.first() + interaction = result.scalars().unique().one() return interaction From 9a9658e12c4f7be38c148b63113795cc4bf37657 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 17:13:43 +0330 Subject: [PATCH 09/23] =?UTF-8?q?=F0=9F=A9=B9=20simplified=20the=20query?= =?UTF-8?q?=20using=20selectin=20within=20the=20relationship=20itself=20in?= =?UTF-8?q?=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 7 +------ ifsguid/endpoints.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index 757f50f..edc0f47 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -34,12 +34,7 @@ async def create_interaction( ) db.add(interaction) await db.commit() - result = await db.execute( - select(models.Interaction) - .options(joinedload(models.Interaction.messages)) - .where(models.Interaction.id == interaction.id) - ) - interaction = result.scalars().unique().one() + await db.refresh(interaction) return interaction diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index 50608a7..d6a63ac 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -15,7 +15,6 @@ async def get_db() -> AsyncSession: yield session - @router.get("/", response_model=str) async def get_root(db: AsyncSession = Depends(get_db)) -> str: return "Hello from IFSGuid!" From ae493c2a3da217425bba27f1142dc8c8186d2875 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 23:48:45 +0330 Subject: [PATCH 10/23] =?UTF-8?q?=F0=9F=A6=BA=20update=20settings=20schema?= =?UTF-8?q?=20to=20support=20all=20models?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/endpoints.py | 7 ++++++- ifsguid/schemas.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index d6a63ac..bc1c15d 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -46,8 +46,13 @@ async def get_interactions( @router.post("/interactions", response_model=schemas.Interaction) async def create_interactions( - settings: schemas.Settings, db: AsyncSession = Depends(get_db) + prompt: schemas.Prompt, + chat_model: schemas.ChatModel = Depends(), + db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: + settings = schemas.Settings( + model=chat_model.model, prompt=prompt.prompt, role=prompt.role + ) interaction = await crud.create_interaction(db=db, settings=settings) return schemas.Interaction.model_validate(interaction) diff --git a/ifsguid/schemas.py b/ifsguid/schemas.py index de6b06d..fe2d31c 100644 --- a/ifsguid/schemas.py +++ b/ifsguid/schemas.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, ConfigDict from typing import List, Literal +from g4f import _all_models + class MessageCreate(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -18,11 +20,17 @@ class Message(MessageCreate): created_at: datetime -class Settings(BaseModel): - model_name: Literal["GPT4", "GPT3"] = "GPT3" +class ChatModel(BaseModel): + model: Literal[tuple(_all_models)] = "gpt-3.5-turbo" + + +class Prompt(BaseModel): role: Literal["System"] = "System" prompt: str +class Settings(ChatModel, Prompt): + pass + class InteractionCreate(BaseModel): model_config = ConfigDict(from_attributes=True) From e93531807dcbb92e1ff1109ab43bbd593b4f8301 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Wed, 29 Nov 2023 23:50:41 +0330 Subject: [PATCH 11/23] =?UTF-8?q?=F0=9F=90=9B=20fixed=20model=20typo=20nam?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_crud.py | 12 ++++++------ tests/test_endpoints.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_crud.py b/tests/test_crud.py index 76f4507..ac5d90b 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -6,10 +6,10 @@ def test_get_interactions(db): interaction1 = models.Interaction( - settings=dict(model_name="model1", role="role1", prompt="prompt1"), + settings=dict(model="model1", role="role1", prompt="prompt1"), ) interaction2 = models.Interaction( - settings=dict(model_name="model2", role="role2", prompt="prompt2"), + settings=dict(model="model2", role="role2", prompt="prompt2"), ) db.add(interaction1) db.add(interaction2) @@ -17,17 +17,17 @@ def test_get_interactions(db): interactions = crud.get_interactions(db) assert len(interactions) == 2 - assert interactions[0].settings["model_name"] == "model1" - assert interactions[1].settings["model_name"] == "model2" + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" def test_get_interaction(db): interaction = models.Interaction( - settings=dict(model_name="model", role="role", prompt="prompt"), + settings=dict(model="model", role="role", prompt="prompt"), ) db.add(interaction) db.commit() retrieved_interaction = crud.get_interaction(db, interaction.id) assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model_name"] == "model" + assert retrieved_interaction.settings["model"] == "model" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 1c2ccd0..be349d0 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -33,6 +33,6 @@ def test_create_interaction(): assert response.status_code == 200 assert response.json()["settings"] == { "prompt": "something", - "model_name": "GPT3", + "model": "GPT3", "role": "System", } From 20891452329e94b783fb6dfd30bcf4365abc7a43 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Thu, 30 Nov 2023 00:28:20 +0330 Subject: [PATCH 12/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20made=20async?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/modules.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ifsguid/modules.py b/ifsguid/modules.py index b614674..861f73c 100644 --- a/ifsguid/modules.py +++ b/ifsguid/modules.py @@ -1,13 +1,19 @@ import g4f +import traceback g4f.debug.logging = True g4f.check_version = False -def generate_ai_response( - content: str, model: str = "GPT4" -) -> str: # TODO: use async version for future - return g4f.ChatCompletion.create( - model=g4f.models.gpt_4 if model == "GPT4" else g4f.models.gpt_35_turbo, - messages=[{"role": "human", "content": content}], - ) +async def generate_ai_response( + content: str, model: g4f.Model = g4f.models.default +) -> str: + try: + response = await g4f.ChatCompletion.create_async( + model=model, + messages=[{"role": "human", "content": content}], + ) + return response + except Exception: + traceback.print_exc() + return "error!" From ce3d18409c9d9c6008977ee46fea7061e770c91a Mon Sep 17 00:00:00 2001 From: Benyamin Date: Thu, 30 Nov 2023 00:32:18 +0330 Subject: [PATCH 13/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Switch=20to=20async?= =?UTF-8?q?=20database=20operations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 16 ++++++++-------- ifsguid/endpoints.py | 31 ++++++++++++++++--------------- ifsguid/schemas.py | 1 + 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index edc0f47..e965484 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -39,7 +39,7 @@ async def create_interaction( return interaction -def delete_interaction(db: Session, id: UUID) -> None: +async def delete_interaction(db: AsyncSession, id: UUID) -> None: interaction = ( db.query(models.Interaction).filter(models.Interaction.id == id).first() ) @@ -52,8 +52,8 @@ def delete_interaction(db: Session, id: UUID) -> None: return False -def update_interaction( - db: Session, id: UUID, settings: schemas.Settings +async def update_interaction( + db: AsyncSession, id: UUID, settings: schemas.Settings ) -> models.Interaction: interaction: models.Interaction = ( db.query(models.Interaction).filter(models.Interaction.id == id).first() @@ -67,8 +67,8 @@ def update_interaction( return None -def get_messages( - db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10 +async def get_messages( + db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10 ) -> List[models.Message]: query = db.query(models.Message) @@ -81,8 +81,8 @@ def get_messages( return query.all() -def create_message( - db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID +async def create_message( + db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID ) -> List[models.Message]: messages_db = [] for msg in messages: @@ -93,5 +93,5 @@ def create_message( db.add(message) messages_db.append(message) - db.commit() + await db.commit() return messages_db diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index bc1c15d..2329884 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -2,7 +2,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session +from g4f import ModelUtils from . import crud, schemas, modules from .database import async_session, AsyncSession @@ -48,7 +48,7 @@ async def get_interactions( async def create_interactions( prompt: schemas.Prompt, chat_model: schemas.ChatModel = Depends(), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> schemas.Interaction: settings = schemas.Settings( model=chat_model.model, prompt=prompt.prompt, role=prompt.role @@ -59,7 +59,7 @@ async def create_interactions( @router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False) -async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: +async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" ) @@ -69,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: "/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False ) async def update_interaction( - id: UUID, settings: schemas.Settings, db: Session = Depends(get_db) + id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" @@ -83,7 +83,7 @@ async def get_all_message_in_interaction( interaction_id: UUID, page: Optional[int] = None, per_page: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> List[schemas.Message]: interaction = crud.get_interaction(db=db, id=str(interaction_id)) @@ -104,9 +104,11 @@ async def get_all_message_in_interaction( "/interactions/{interactions_id}/messages", response_model=List[schemas.Message] ) async def create_message( - interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db) + interaction_id: UUID, + message: schemas.MessageCreate, + db: AsyncSession = Depends(get_db), ) -> schemas.Message: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( @@ -117,17 +119,16 @@ async def create_message( messages = [] if message.role == "human": - ai_content = modules.generate_ai_response( - content=message.content, model=interaction.settings.model_name + ai_content = await modules.generate_ai_response( + content=message.content, + model=ModelUtils.convert[interaction.settings.model], ) ai_message = schemas.MessageCreate(role="ai", content=ai_content) messages.append(message) messages.append(ai_message) - return [ - schemas.Message.model_validate(message) - for message in crud.create_message( - db=db, messages=messages, interaction_id=str(interaction_id) - ) - ] + messages = await crud.create_message( + db=db, messages=messages, interaction_id=str(interaction_id) + ) + return [schemas.Message.model_validate(message) for message in messages] diff --git a/ifsguid/schemas.py b/ifsguid/schemas.py index fe2d31c..c005748 100644 --- a/ifsguid/schemas.py +++ b/ifsguid/schemas.py @@ -28,6 +28,7 @@ class Prompt(BaseModel): role: Literal["System"] = "System" prompt: str + class Settings(ChatModel, Prompt): pass From ca65ce5ba3e6afa73294574fc992a4832a263418 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Thu, 30 Nov 2023 11:37:23 +0330 Subject: [PATCH 14/23] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20update=20methods=20t?= =?UTF-8?q?o=20use=20SQLAlchemy=20Core=20and=20async?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/crud.py | 36 +++++++++++++++++++----------------- ifsguid/endpoints.py | 13 ++++++------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/ifsguid/crud.py b/ifsguid/crud.py index e965484..05225c8 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -1,7 +1,8 @@ from typing import List from uuid import UUID -from sqlalchemy.orm import Session, joinedload, selectinload +from sqlalchemy import delete, update +from sqlalchemy.orm import joinedload from sqlalchemy.future import select from sqlalchemy.ext.asyncio import AsyncSession @@ -40,13 +41,11 @@ async def create_interaction( async def delete_interaction(db: AsyncSession, id: UUID) -> None: - interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() - ) + stmt = delete(models.Interaction).where(models.Interaction.id == id) + result = await db.execute(stmt) - if interaction is not None: - db.delete(interaction) - db.commit() + if result.rowcount: + await db.commit() return True return False @@ -55,14 +54,16 @@ async def delete_interaction(db: AsyncSession, id: UUID) -> None: async def update_interaction( db: AsyncSession, id: UUID, settings: schemas.Settings ) -> models.Interaction: - interaction: models.Interaction = ( - db.query(models.Interaction).filter(models.Interaction.id == id).first() + stmt = ( + update(models.Interaction) + .where(models.Interaction.id == id) + .values(settings=settings) ) + result = await db.execute(stmt) - if interaction is not None: - interaction.settings = settings - db.commit() - return interaction + if result.rowcount: + await db.commit() + return True return None @@ -70,15 +71,16 @@ async def update_interaction( async def get_messages( db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10 ) -> List[models.Message]: - query = db.query(models.Message) + stmt = select(models.Message) if interaction_id is not None: - query = query.filter(models.Message.interaction_id == interaction_id) + stmt = stmt.where(models.Message.interaction_id == interaction_id) if page is not None: - query = query.offset((page - 1) * per_page).limit(per_page) + stmt = stmt.offset((page - 1) * per_page).limit(per_page) - return query.all() + result = await db.execute(stmt) + return result.scalars().all() async def create_message( diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index 2329884..b16f4b4 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -85,19 +85,18 @@ async def get_all_message_in_interaction( per_page: Optional[int] = None, db: AsyncSession = Depends(get_db), ) -> List[schemas.Message]: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Interaction not found" ) - return [ - schemas.Message.model_validate(message) - for message in crud.get_messages( - db=db, interaction_id=str(interaction_id), page=page, per_page=per_page - ) - ] + messages = await crud.get_messages( + db=db, interaction_id=str(interaction_id), page=page, per_page=per_page + ) + + return [schemas.Message.model_validate(message) for message in messages] @router.post( From d17e405fc0006c0306e6950edd71757bc299eca2 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 17:34:44 +0330 Subject: [PATCH 15/23] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20add=20hinting=20t?= =?UTF-8?q?ype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ifsguid/database.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ifsguid/database.py b/ifsguid/database.py index 940091e..27da591 100644 --- a/ifsguid/database.py +++ b/ifsguid/database.py @@ -1,9 +1,11 @@ -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine from sqlalchemy.orm import sessionmaker from .config import settings -engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URI.unicode_string()) +engine: AsyncEngine = create_async_engine( + settings.SQLALCHEMY_DATABASE_URI.unicode_string() +) async_session = sessionmaker( autocommit=False, autoflush=False, From 119e591df7d2fcc54952a311c36a3e54938690ff Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 17:35:20 +0330 Subject: [PATCH 16/23] =?UTF-8?q?=E2=9E=95=20UPDATE=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- poetry.lock | 35 ++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index ce21d6e..6585efc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -133,6 +133,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.19.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"}, + {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"}, +] + +[package.extras] +dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] +docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "alembic" version = "1.12.1" @@ -2073,6 +2088,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -2689,4 +2722,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "2cc7cafc6637a43ce418bb8971e5b09134f4750d03e4c830b15f26625d1ea28d" +content-hash = "60b8ce6c40d2b0d4d9c6d306b5ad17ca4cf02951f0a8941808b4603832946991" diff --git a/pyproject.toml b/pyproject.toml index d4f5b0e..a7c927e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ requests = "^2.31.0" pytest = "7.4.2" black = "^23.10.1" pytest-cov = "^4.1.0" +pytest-asyncio = "^0.21.1" +aiosqlite = "^0.19.0" [build-system] requires = ["poetry-core>=1.0.0"] From 2bd4e2b7be7fdcaa89a39960c814a1a78adf6f60 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 17:37:24 +0330 Subject: [PATCH 17/23] =?UTF-8?q?=F0=9F=94=A8=20update=20tests=20to=20asyn?= =?UTF-8?q?c?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/client.py | 31 +++++++++--------- tests/test_crud.py | 69 +++++++++++++++++++++++++---------------- tests/test_endpoints.py | 68 ++++++++++++++++++++++++---------------- 3 files changed, 98 insertions(+), 70 deletions(-) diff --git a/tests/client.py b/tests/client.py index 0af0ec6..c8d3215 100644 --- a/tests/client.py +++ b/tests/client.py @@ -1,32 +1,31 @@ -from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from fastapi.testclient import TestClient -from sqlalchemy.pool import StaticPool +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine from ifsguid import models from ifsguid.main import app from ifsguid.endpoints import get_db -SQLALCHEMY_DATABASE_URL = "sqlite://" +DATABASE_URL = "sqlite+aiosqlite:///:memory:" -engine = create_engine( - SQLALCHEMY_DATABASE_URL, - connect_args={"check_same_thread": False}, - poolclass=StaticPool, -) -TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +engine: AsyncEngine = create_async_engine(DATABASE_URL) +async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) -# models.Base.metadata.create_all(bind=engine) +async def create_tables(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.create_all) -def override_get_db() -> TestingSessionLocal: - try: - models.Base.metadata.create_all(bind=engine) - session = TestingSessionLocal() + +async def drop_tables(): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.drop_all) + + +async def override_get_db() -> AsyncSession: + async with async_session() as session: yield session - finally: - session.close() app.dependency_overrides[get_db] = override_get_db diff --git a/tests/test_crud.py b/tests/test_crud.py index ac5d90b..00c5b28 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -1,33 +1,48 @@ +import pytest + from ifsguid import crud, models, schemas +from . import client ### Unit Tests ### -def test_get_interactions(db): - interaction1 = models.Interaction( - settings=dict(model="model1", role="role1", prompt="prompt1"), - ) - interaction2 = models.Interaction( - settings=dict(model="model2", role="role2", prompt="prompt2"), - ) - db.add(interaction1) - db.add(interaction2) - db.commit() - - interactions = crud.get_interactions(db) - assert len(interactions) == 2 - assert interactions[0].settings["model"] == "model1" - assert interactions[1].settings["model"] == "model2" - - -def test_get_interaction(db): - interaction = models.Interaction( - settings=dict(model="model", role="role", prompt="prompt"), - ) - db.add(interaction) - db.commit() - - retrieved_interaction = crud.get_interaction(db, interaction.id) - assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model"] == "model" +@pytest.mark.asyncio +async def test_get_interactions(db): + async with client.async_session() as db: + try: + await client.create_tables() + interaction1 = models.Interaction( + settings=dict(model="model1", role="role1", prompt="prompt1"), + ) + interaction2 = models.Interaction( + settings=dict(model="model2", role="role2", prompt="prompt2"), + ) + db.add(interaction1) + db.add(interaction2) + await db.commit() + + interactions = await crud.get_interactions(db) + assert len(interactions) == 2 + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" + finally: + await client.drop_tables() + + +@pytest.mark.asyncio +async def test_get_interaction(db): + async with client.async_session() as db: + try: + await client.create_tables() + interaction = models.Interaction( + settings=dict(model="model", role="role", prompt="prompt"), + ) + db.add(interaction) + await db.commit() + + retrieved_interaction = await crud.get_interaction(db, interaction.id) + assert retrieved_interaction.id == interaction.id + assert retrieved_interaction.settings["model"] == "model" + finally: + await client.drop_tables() diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index be349d0..95653d4 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,38 +1,52 @@ +import pytest + from ifsguid import models -from .client import client +from . import client ### Integration Tests ### def test_get_root(): - response = client.get("/api") + response = client.client.get("/api") assert response.status_code == 200 assert response.json() == "Hello from IFSGuid!" -def test_get_all_interactions(db): - interaction1 = models.Interaction(settings={"prompt": "something"}) - interaction2 = models.Interaction(settings={"prompt": "something else"}) - db.add(interaction1) - db.add(interaction2) - db.commit() - - response = client.get("/api/interactions") - assert response.status_code == 200 - assert len(response.json()) == 2 - - -def test_create_interaction(): - response = client.post( - "/api/interactions", - json={ - "prompt": "something", - }, - ) - assert response.status_code == 200 - assert response.json()["settings"] == { - "prompt": "something", - "model": "GPT3", - "role": "System", - } +@pytest.mark.asyncio +async def test_get_all_interactions(db): + async with client.async_session() as db: + try: + await client.create_tables() + interaction1 = models.Interaction(settings={"prompt": "something"}) + interaction2 = models.Interaction(settings={"prompt": "something else"}) + db.add(interaction1) + db.add(interaction2) + await db.commit() + + response = client.client.get("/api/interactions") + assert response.status_code == 200 + assert len(response.json()) == 2 + finally: + await client.drop_tables() + + +@pytest.mark.asyncio +async def test_create_interaction(): + async with client.async_session() as db: + try: + await client.create_tables() + response = client.client.post( + "/api/interactions", + json={ + "prompt": "something", + }, + ) + assert response.status_code == 200 + assert response.json()["settings"] == { + "prompt": "something", + "model": "gpt-3.5-turbo", + "role": "System", + } + finally: + await client.drop_tables() From 8ece2d9700cb9458911b6f05a3a28dcabeef1ae3 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 17:45:55 +0330 Subject: [PATCH 18/23] =?UTF-8?q?=F0=9F=94=A8=20update=20to=20async=20vers?= =?UTF-8?q?ion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 759adf4..9612052 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,9 @@ import pytest -from ifsguid import models from .client import override_get_db, engine @pytest.fixture -def db(): - try: - yield from override_get_db() - finally: - models.Base.metadata.drop_all(bind=engine) +async def db(): + async for session in override_get_db(): + yield session From 0cd62ab0718786b46cf416b23215d824b68e2d31 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 17:59:44 +0330 Subject: [PATCH 19/23] =?UTF-8?q?=F0=9F=90=9B=20fixed=20postgresql=20async?= =?UTF-8?q?=20url?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/github-actions.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index fc1eea4..853d755 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -36,12 +36,12 @@ jobs: run: | poetry run pytest -s env: - SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db + SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db - name: Code Coverage run: | poetry run pytest --cov=./ifsguid --cov-report=xml --cov-report=term-missing env: - SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db + SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 env: From c5490f62371857673f60acf7494f5dd0d9d8c191 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 18:13:57 +0330 Subject: [PATCH 20/23] =?UTF-8?q?=F0=9F=91=B7=20add=20codecov=20policy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .codecov.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..7687239 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,7 @@ +coverage: + status: + project: + default: + # basic + target: auto + threshold: null From 9c52bfed2098639d6b2cfbf30b8898e3f545f950 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 18:33:40 +0330 Subject: [PATCH 21/23] =?UTF-8?q?=F0=9F=93=9D=20=20update=20SQLALCHEMY=5FD?= =?UTF-8?q?ATABASE=5FURI=20to=20use=20asyncpg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 81813bd..fc8fce2 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,5 @@ To launch an API instance, you should: You can also run the project via `docker-compose` (i.e. `docker compose up -d`) on port `80` in which you would need the [.docker.env](/.docker.env) containing the following variable to create the database: ``` -SQLALCHEMY_DATABASE_URI=postgresql://:@ifsguid_db/ +SQLALCHEMY_DATABASE_URI=postgresql+asyncpg://:@ifsguid_db/ ``` - - From 5424607b06e01f215fd8c3e820bbed81f88b1565 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 19:54:43 +0330 Subject: [PATCH 22/23] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor=20pytest=20?= =?UTF-8?q?fixture?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 6 +++-- tests/test_crud.py | 58 ++++++++++++++++++----------------------- tests/test_endpoints.py | 21 ++++++--------- 3 files changed, 37 insertions(+), 48 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9612052..6ed6013 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ import pytest -from .client import override_get_db, engine +from . import client @pytest.fixture async def db(): - async for session in override_get_db(): + async with client.async_session() as session: + await client.create_tables() yield session + await client.drop_tables() diff --git a/tests/test_crud.py b/tests/test_crud.py index 00c5b28..0b7d373 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -9,40 +9,32 @@ @pytest.mark.asyncio async def test_get_interactions(db): - async with client.async_session() as db: - try: - await client.create_tables() - interaction1 = models.Interaction( - settings=dict(model="model1", role="role1", prompt="prompt1"), - ) - interaction2 = models.Interaction( - settings=dict(model="model2", role="role2", prompt="prompt2"), - ) - db.add(interaction1) - db.add(interaction2) - await db.commit() - - interactions = await crud.get_interactions(db) - assert len(interactions) == 2 - assert interactions[0].settings["model"] == "model1" - assert interactions[1].settings["model"] == "model2" - finally: - await client.drop_tables() + async for db in db: # TODO + interaction1 = models.Interaction( + settings=dict(model="model1", role="role1", prompt="prompt1"), + ) + interaction2 = models.Interaction( + settings=dict(model="model2", role="role2", prompt="prompt2"), + ) + db.add(interaction1) + db.add(interaction2) + await db.commit() + + interactions = await crud.get_interactions(db) + assert len(interactions) == 2 + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" @pytest.mark.asyncio async def test_get_interaction(db): - async with client.async_session() as db: - try: - await client.create_tables() - interaction = models.Interaction( - settings=dict(model="model", role="role", prompt="prompt"), - ) - db.add(interaction) - await db.commit() - - retrieved_interaction = await crud.get_interaction(db, interaction.id) - assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model"] == "model" - finally: - await client.drop_tables() + async for db in db: # TODO + interaction = models.Interaction( + settings=dict(model="model", role="role", prompt="prompt"), + ) + db.add(interaction) + await db.commit() + + retrieved_interaction = await crud.get_interaction(db, interaction.id) + assert retrieved_interaction.id == interaction.id + assert retrieved_interaction.settings["model"] == "model" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 95653d4..19fb5fe 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -15,20 +15,15 @@ def test_get_root(): @pytest.mark.asyncio async def test_get_all_interactions(db): - async with client.async_session() as db: - try: - await client.create_tables() - interaction1 = models.Interaction(settings={"prompt": "something"}) - interaction2 = models.Interaction(settings={"prompt": "something else"}) - db.add(interaction1) - db.add(interaction2) - await db.commit() + interaction1 = models.Interaction(settings={"prompt": "something"}) + interaction2 = models.Interaction(settings={"prompt": "something else"}) + db.add(interaction1) + db.add(interaction2) + yield db.commit() # TODO - response = client.client.get("/api/interactions") - assert response.status_code == 200 - assert len(response.json()) == 2 - finally: - await client.drop_tables() + response = client.client.get("/api/interactions") + assert response.status_code == 200 + assert len(response.json()) == 2 @pytest.mark.asyncio From 3db442345a0846d0c936c7fd3f1fff7212591d46 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 22:11:07 +0330 Subject: [PATCH 23/23] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor=20to=20use?= =?UTF-8?q?=20pytest=5Fasyncio=20for=20the=20fixture?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 8 ++++--- tests/test_crud.py | 48 ++++++++++++++++++++--------------------- tests/test_endpoints.py | 33 ++++++++++++---------------- 3 files changed, 42 insertions(+), 47 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6ed6013..badef90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ -import pytest +import pytest_asyncio + +from sqlalchemy.ext.asyncio import AsyncSession from . import client -@pytest.fixture -async def db(): +@pytest_asyncio.fixture() +async def db() -> AsyncSession: async with client.async_session() as session: await client.create_tables() yield session diff --git a/tests/test_crud.py b/tests/test_crud.py index 0b7d373..47cf94b 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -9,32 +9,30 @@ @pytest.mark.asyncio async def test_get_interactions(db): - async for db in db: # TODO - interaction1 = models.Interaction( - settings=dict(model="model1", role="role1", prompt="prompt1"), - ) - interaction2 = models.Interaction( - settings=dict(model="model2", role="role2", prompt="prompt2"), - ) - db.add(interaction1) - db.add(interaction2) - await db.commit() - - interactions = await crud.get_interactions(db) - assert len(interactions) == 2 - assert interactions[0].settings["model"] == "model1" - assert interactions[1].settings["model"] == "model2" + interaction1 = models.Interaction( + settings=dict(model="model1", role="role1", prompt="prompt1"), + ) + interaction2 = models.Interaction( + settings=dict(model="model2", role="role2", prompt="prompt2"), + ) + db.add(interaction1) + db.add(interaction2) + await db.commit() + + interactions = await crud.get_interactions(db) + assert len(interactions) == 2 + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" @pytest.mark.asyncio async def test_get_interaction(db): - async for db in db: # TODO - interaction = models.Interaction( - settings=dict(model="model", role="role", prompt="prompt"), - ) - db.add(interaction) - await db.commit() - - retrieved_interaction = await crud.get_interaction(db, interaction.id) - assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model"] == "model" + interaction = models.Interaction( + settings=dict(model="model", role="role", prompt="prompt"), + ) + db.add(interaction) + await db.commit() + + retrieved_interaction = await crud.get_interaction(db, interaction.id) + assert retrieved_interaction.id == interaction.id + assert retrieved_interaction.settings["model"] == "model" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 19fb5fe..febc714 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -19,7 +19,7 @@ async def test_get_all_interactions(db): interaction2 = models.Interaction(settings={"prompt": "something else"}) db.add(interaction1) db.add(interaction2) - yield db.commit() # TODO + await db.commit() response = client.client.get("/api/interactions") assert response.status_code == 200 @@ -27,21 +27,16 @@ async def test_get_all_interactions(db): @pytest.mark.asyncio -async def test_create_interaction(): - async with client.async_session() as db: - try: - await client.create_tables() - response = client.client.post( - "/api/interactions", - json={ - "prompt": "something", - }, - ) - assert response.status_code == 200 - assert response.json()["settings"] == { - "prompt": "something", - "model": "gpt-3.5-turbo", - "role": "System", - } - finally: - await client.drop_tables() +async def test_create_interaction(db): + response = client.client.post( + "/api/interactions", + json={ + "prompt": "something", + }, + ) + assert response.status_code == 200 + assert response.json()["settings"] == { + "prompt": "something", + "model": "gpt-3.5-turbo", + "role": "System", + }