-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
98 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,48 @@ | ||
import pytest | ||
|
||
from ifsguid import crud, models, schemas | ||
from . import client | ||
|
||
|
||
### Unit Tests ### | ||
|
||
|
||
def test_get_interactions(db): | ||
interaction1 = models.Interaction( | ||
settings=dict(model="model1", role="role1", prompt="prompt1"), | ||
) | ||
interaction2 = models.Interaction( | ||
settings=dict(model="model2", role="role2", prompt="prompt2"), | ||
) | ||
db.add(interaction1) | ||
db.add(interaction2) | ||
db.commit() | ||
|
||
interactions = crud.get_interactions(db) | ||
assert len(interactions) == 2 | ||
assert interactions[0].settings["model"] == "model1" | ||
assert interactions[1].settings["model"] == "model2" | ||
|
||
|
||
def test_get_interaction(db): | ||
interaction = models.Interaction( | ||
settings=dict(model="model", role="role", prompt="prompt"), | ||
) | ||
db.add(interaction) | ||
db.commit() | ||
|
||
retrieved_interaction = crud.get_interaction(db, interaction.id) | ||
assert retrieved_interaction.id == interaction.id | ||
assert retrieved_interaction.settings["model"] == "model" | ||
@pytest.mark.asyncio | ||
async def test_get_interactions(db): | ||
async with client.async_session() as db: | ||
try: | ||
await client.create_tables() | ||
interaction1 = models.Interaction( | ||
settings=dict(model="model1", role="role1", prompt="prompt1"), | ||
) | ||
interaction2 = models.Interaction( | ||
settings=dict(model="model2", role="role2", prompt="prompt2"), | ||
) | ||
db.add(interaction1) | ||
db.add(interaction2) | ||
await db.commit() | ||
|
||
interactions = await crud.get_interactions(db) | ||
assert len(interactions) == 2 | ||
assert interactions[0].settings["model"] == "model1" | ||
assert interactions[1].settings["model"] == "model2" | ||
finally: | ||
await client.drop_tables() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_interaction(db): | ||
async with client.async_session() as db: | ||
try: | ||
await client.create_tables() | ||
interaction = models.Interaction( | ||
settings=dict(model="model", role="role", prompt="prompt"), | ||
) | ||
db.add(interaction) | ||
await db.commit() | ||
|
||
retrieved_interaction = await crud.get_interaction(db, interaction.id) | ||
assert retrieved_interaction.id == interaction.id | ||
assert retrieved_interaction.settings["model"] == "model" | ||
finally: | ||
await client.drop_tables() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,52 @@ | ||
import pytest | ||
|
||
from ifsguid import models | ||
from .client import client | ||
from . import client | ||
|
||
|
||
### Integration Tests ### | ||
|
||
|
||
def test_get_root(): | ||
response = client.get("/api") | ||
response = client.client.get("/api") | ||
assert response.status_code == 200 | ||
assert response.json() == "Hello from IFSGuid!" | ||
|
||
|
||
def test_get_all_interactions(db): | ||
interaction1 = models.Interaction(settings={"prompt": "something"}) | ||
interaction2 = models.Interaction(settings={"prompt": "something else"}) | ||
db.add(interaction1) | ||
db.add(interaction2) | ||
db.commit() | ||
|
||
response = client.get("/api/interactions") | ||
assert response.status_code == 200 | ||
assert len(response.json()) == 2 | ||
|
||
|
||
def test_create_interaction(): | ||
response = client.post( | ||
"/api/interactions", | ||
json={ | ||
"prompt": "something", | ||
}, | ||
) | ||
assert response.status_code == 200 | ||
assert response.json()["settings"] == { | ||
"prompt": "something", | ||
"model": "GPT3", | ||
"role": "System", | ||
} | ||
@pytest.mark.asyncio | ||
async def test_get_all_interactions(db): | ||
async with client.async_session() as db: | ||
try: | ||
await client.create_tables() | ||
interaction1 = models.Interaction(settings={"prompt": "something"}) | ||
interaction2 = models.Interaction(settings={"prompt": "something else"}) | ||
db.add(interaction1) | ||
db.add(interaction2) | ||
await db.commit() | ||
|
||
response = client.client.get("/api/interactions") | ||
assert response.status_code == 200 | ||
assert len(response.json()) == 2 | ||
finally: | ||
await client.drop_tables() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_interaction(): | ||
async with client.async_session() as db: | ||
try: | ||
await client.create_tables() | ||
response = client.client.post( | ||
"/api/interactions", | ||
json={ | ||
"prompt": "something", | ||
}, | ||
) | ||
assert response.status_code == 200 | ||
assert response.json()["settings"] == { | ||
"prompt": "something", | ||
"model": "gpt-3.5-turbo", | ||
"role": "System", | ||
} | ||
finally: | ||
await client.drop_tables() |