Skip to content

Commit

Permalink
🔨 update tests to async
Browse files Browse the repository at this point in the history
  • Loading branch information
agn-7 committed Dec 1, 2023
1 parent 119e591 commit 2bd4e2b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 70 deletions.
31 changes: 15 additions & 16 deletions tests/client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine

from ifsguid import models
from ifsguid.main import app
from ifsguid.endpoints import get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"
DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
engine: AsyncEngine = create_async_engine(DATABASE_URL)

async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

# models.Base.metadata.create_all(bind=engine)

async def create_tables():
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)

def override_get_db() -> TestingSessionLocal:
try:
models.Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()

async def drop_tables():
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.drop_all)


async def override_get_db() -> AsyncSession:
async with async_session() as session:
yield session
finally:
session.close()


app.dependency_overrides[get_db] = override_get_db
Expand Down
69 changes: 42 additions & 27 deletions tests/test_crud.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,48 @@
import pytest

from ifsguid import crud, models, schemas
from . import client


### Unit Tests ###


def test_get_interactions(db):
interaction1 = models.Interaction(
settings=dict(model="model1", role="role1", prompt="prompt1"),
)
interaction2 = models.Interaction(
settings=dict(model="model2", role="role2", prompt="prompt2"),
)
db.add(interaction1)
db.add(interaction2)
db.commit()

interactions = crud.get_interactions(db)
assert len(interactions) == 2
assert interactions[0].settings["model"] == "model1"
assert interactions[1].settings["model"] == "model2"


def test_get_interaction(db):
interaction = models.Interaction(
settings=dict(model="model", role="role", prompt="prompt"),
)
db.add(interaction)
db.commit()

retrieved_interaction = crud.get_interaction(db, interaction.id)
assert retrieved_interaction.id == interaction.id
assert retrieved_interaction.settings["model"] == "model"
@pytest.mark.asyncio
async def test_get_interactions(db):
async with client.async_session() as db:
try:
await client.create_tables()
interaction1 = models.Interaction(
settings=dict(model="model1", role="role1", prompt="prompt1"),
)
interaction2 = models.Interaction(
settings=dict(model="model2", role="role2", prompt="prompt2"),
)
db.add(interaction1)
db.add(interaction2)
await db.commit()

interactions = await crud.get_interactions(db)
assert len(interactions) == 2
assert interactions[0].settings["model"] == "model1"
assert interactions[1].settings["model"] == "model2"
finally:
await client.drop_tables()


@pytest.mark.asyncio
async def test_get_interaction(db):
async with client.async_session() as db:
try:
await client.create_tables()
interaction = models.Interaction(
settings=dict(model="model", role="role", prompt="prompt"),
)
db.add(interaction)
await db.commit()

retrieved_interaction = await crud.get_interaction(db, interaction.id)
assert retrieved_interaction.id == interaction.id
assert retrieved_interaction.settings["model"] == "model"
finally:
await client.drop_tables()
68 changes: 41 additions & 27 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
import pytest

from ifsguid import models
from .client import client
from . import client


### Integration Tests ###


def test_get_root():
response = client.get("/api")
response = client.client.get("/api")
assert response.status_code == 200
assert response.json() == "Hello from IFSGuid!"


def test_get_all_interactions(db):
interaction1 = models.Interaction(settings={"prompt": "something"})
interaction2 = models.Interaction(settings={"prompt": "something else"})
db.add(interaction1)
db.add(interaction2)
db.commit()

response = client.get("/api/interactions")
assert response.status_code == 200
assert len(response.json()) == 2


def test_create_interaction():
response = client.post(
"/api/interactions",
json={
"prompt": "something",
},
)
assert response.status_code == 200
assert response.json()["settings"] == {
"prompt": "something",
"model": "GPT3",
"role": "System",
}
@pytest.mark.asyncio
async def test_get_all_interactions(db):
async with client.async_session() as db:
try:
await client.create_tables()
interaction1 = models.Interaction(settings={"prompt": "something"})
interaction2 = models.Interaction(settings={"prompt": "something else"})
db.add(interaction1)
db.add(interaction2)
await db.commit()

response = client.client.get("/api/interactions")
assert response.status_code == 200
assert len(response.json()) == 2
finally:
await client.drop_tables()


@pytest.mark.asyncio
async def test_create_interaction():
async with client.async_session() as db:
try:
await client.create_tables()
response = client.client.post(
"/api/interactions",
json={
"prompt": "something",
},
)
assert response.status_code == 200
assert response.json()["settings"] == {
"prompt": "something",
"model": "gpt-3.5-turbo",
"role": "System",
}
finally:
await client.drop_tables()

0 comments on commit 2bd4e2b

Please sign in to comment.