From 3db442345a0846d0c936c7fd3f1fff7212591d46 Mon Sep 17 00:00:00 2001 From: Benyamin Date: Fri, 1 Dec 2023 22:11:07 +0330 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor=20to=20use=20pyte?= =?UTF-8?q?st=5Fasyncio=20for=20the=20fixture?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 8 ++++--- tests/test_crud.py | 48 ++++++++++++++++++++--------------------- tests/test_endpoints.py | 33 ++++++++++++---------------- 3 files changed, 42 insertions(+), 47 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6ed6013..badef90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ -import pytest +import pytest_asyncio + +from sqlalchemy.ext.asyncio import AsyncSession from . import client -@pytest.fixture -async def db(): +@pytest_asyncio.fixture() +async def db() -> AsyncSession: async with client.async_session() as session: await client.create_tables() yield session diff --git a/tests/test_crud.py b/tests/test_crud.py index 0b7d373..47cf94b 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -9,32 +9,30 @@ @pytest.mark.asyncio async def test_get_interactions(db): - async for db in db: # TODO - interaction1 = models.Interaction( - settings=dict(model="model1", role="role1", prompt="prompt1"), - ) - interaction2 = models.Interaction( - settings=dict(model="model2", role="role2", prompt="prompt2"), - ) - db.add(interaction1) - db.add(interaction2) - await db.commit() - - interactions = await crud.get_interactions(db) - assert len(interactions) == 2 - assert interactions[0].settings["model"] == "model1" - assert interactions[1].settings["model"] == "model2" + interaction1 = models.Interaction( + settings=dict(model="model1", role="role1", prompt="prompt1"), + ) + interaction2 = models.Interaction( + settings=dict(model="model2", role="role2", prompt="prompt2"), + ) + db.add(interaction1) + db.add(interaction2) + await db.commit() + + interactions = await crud.get_interactions(db) + assert len(interactions) == 2 + assert interactions[0].settings["model"] == "model1" + assert interactions[1].settings["model"] == "model2" @pytest.mark.asyncio async def test_get_interaction(db): - async for db in db: # TODO - interaction = models.Interaction( - settings=dict(model="model", role="role", prompt="prompt"), - ) - db.add(interaction) - await db.commit() - - retrieved_interaction = await crud.get_interaction(db, interaction.id) - assert retrieved_interaction.id == interaction.id - assert retrieved_interaction.settings["model"] == "model" + interaction = models.Interaction( + settings=dict(model="model", role="role", prompt="prompt"), + ) + db.add(interaction) + await db.commit() + + retrieved_interaction = await crud.get_interaction(db, interaction.id) + assert retrieved_interaction.id == interaction.id + assert retrieved_interaction.settings["model"] == "model" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 19fb5fe..febc714 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -19,7 +19,7 @@ async def test_get_all_interactions(db): interaction2 = models.Interaction(settings={"prompt": "something else"}) db.add(interaction1) db.add(interaction2) - yield db.commit() # TODO + await db.commit() response = client.client.get("/api/interactions") assert response.status_code == 200 @@ -27,21 +27,16 @@ async def test_get_all_interactions(db): @pytest.mark.asyncio -async def test_create_interaction(): - async with client.async_session() as db: - try: - await client.create_tables() - response = client.client.post( - "/api/interactions", - json={ - "prompt": "something", - }, - ) - assert response.status_code == 200 - assert response.json()["settings"] == { - "prompt": "something", - "model": "gpt-3.5-turbo", - "role": "System", - } - finally: - await client.drop_tables() +async def test_create_interaction(db): + response = client.client.post( + "/api/interactions", + json={ + "prompt": "something", + }, + ) + assert response.status_code == 200 + assert response.json()["settings"] == { + "prompt": "something", + "model": "gpt-3.5-turbo", + "role": "System", + }