Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Jun 16, 2024
1 parent 0add099 commit 52446ea
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 38 deletions.
68 changes: 41 additions & 27 deletions linguaphoto/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,17 @@ def cf(self) -> CloudFrontClient:

async def _init_dynamodb(self, session: aioboto3.Session) -> Self:
db = session.resource("dynamodb")
await db.__aenter__()
self.__db = db
self.__db = await db.__aenter__()
return self

async def _init_cloudfront(self, session: aioboto3.Session) -> Self:
cf = session.client("cloudfront")
await cf.__aenter__()
self.__cf = cf
self.__cf = await cf.__aenter__()
return self

async def _init_s3(self, session: aioboto3.Session) -> Self:
s3 = session.client("s3")
await s3.__aenter__()
self.__s3 = s3
self.__s3 = await s3.__aenter__()
return self

async def _init_redis(self) -> Self:
Expand All @@ -75,8 +72,7 @@ async def _init_redis(self) -> Self:
port=settings.redis.port,
db=settings.redis.db,
)
await kv.__aenter__()
self.__kv = kv
self.__kv = await kv.__aenter__()
return self

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -106,7 +102,7 @@ async def _create_dynamodb_table(
self,
name: str,
keys: list[tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]],
gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] = [],
gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] | None = None,
deletion_protection: bool = False,
) -> None:
"""Creates a table in the Dynamo database if a table of that name does not already exist.
Expand All @@ -123,26 +119,44 @@ async def _create_dynamodb_table(
"""
try:
await self.db.meta.client.describe_table(TableName=name)
logger.info("Found existing table %s", name)

except ClientError:
logger.info("Creating %s table", name)
table = await self.db.create_table(
AttributeDefinitions=[
{"AttributeName": n, "AttributeType": t}
for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis))
],
TableName=name,
KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys],
GlobalSecondaryIndexes=[
{
"IndexName": i,
"KeySchema": [{"AttributeName": n, "KeyType": t}],
"Projection": {"ProjectionType": "ALL"},
}
for i, n, _, t in gsis
],
DeletionProtectionEnabled=deletion_protection,
BillingMode="PAY_PER_REQUEST",
)

if gsis is None:
table = await self.db.create_table(
AttributeDefinitions=[
{"AttributeName": n, "AttributeType": t} for n, t in ((n, t) for (n, t, _) in keys)
],
TableName=name,
KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys],
DeletionProtectionEnabled=deletion_protection,
BillingMode="PAY_PER_REQUEST",
)

else:
table = await self.db.create_table(
AttributeDefinitions=[
{"AttributeName": n, "AttributeType": t}
for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis))
],
TableName=name,
KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys],
GlobalSecondaryIndexes=(
[
{
"IndexName": i,
"KeySchema": [{"AttributeName": n, "KeyType": t}],
"Projection": {"ProjectionType": "ALL"},
}
for i, n, _, t in gsis
]
),
DeletionProtectionEnabled=deletion_protection,
BillingMode="PAY_PER_REQUEST",
)

await table.wait_until_exists()

async def _delete_dynamodb_table(self, name: str) -> None:
Expand Down
46 changes: 44 additions & 2 deletions linguaphoto/db.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Defines base tools for interacting with the database."""

import argparse
import asyncio
import logging
import uuid
from typing import AsyncGenerator, Self

import argparse
from linguaphoto.crud.base import BaseCrud
from linguaphoto.crud.images import ImagesCrud
from linguaphoto.crud.users import UserCrud
from linguaphoto.model import User
from linguaphoto.settings import settings


class Crud(
Expand Down Expand Up @@ -45,7 +48,6 @@ async def create_tables(crud: Crud | None = None, deletion_protection: bool = Fa
],
gsis=[
("emailIndex", "email", "S", "HASH"),
("usernameIndex", "username", "S", "HASH"),
],
deletion_protection=deletion_protection,
),
Expand All @@ -54,10 +56,50 @@ async def create_tables(crud: Crud | None = None, deletion_protection: bool = Fa
keys=[
("image_id", "S", "HASH"),
],
gsis=[
("userIndex", "user_id", "S", "HASH"),
],
deletion_protection=deletion_protection,
),
)


async def delete_tables(crud: Crud | None = None) -> None:
"""Deletes all of the database tables.
Args:
crud: The top-level CRUD class.
"""
logging.basicConfig(level=logging.INFO)

if crud is None:
async with Crud() as crud:
await delete_tables(crud)

else:
await asyncio.gather(
crud._delete_dynamodb_table("Users"),
crud._delete_dynamodb_table("Images"),
)


async def populate_with_dummy_data(crud: Crud | None = None) -> None:
"""Populates the database with dummy data.
Args:
crud: The top-level CRUD class.
"""
logging.basicConfig(level=logging.INFO)

if crud is None:
async with Crud() as crud:
await populate_with_dummy_data(crud)

else:
assert (test_user := settings.user.test_user) is not None
await crud.add_user(user=User(user_id=str(uuid.uuid4()), email=test_user.email))


async def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("action", choices=["create", "delete", "populate"])
Expand Down
1 change: 1 addition & 0 deletions linguaphoto/settings/configs/local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ crypto:
jwt_secret: fakeJwtSecret
site:
homepage: http://127.0.0.1:3000
aws:
image_bucket_id: linguaphoto-images
user:
test_user:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ include = '\.pyi?$'

[tool.pytest.ini_options]

addopts = "-rx -rf -x -q --full-trace"
testpaths = ["tests"]

markers = [
Expand Down
16 changes: 8 additions & 8 deletions tests/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
import asyncio

from fastapi.testclient import TestClient
from pytest_mock.plugin import MockType

from linguaphoto.db import create_tables
from linguaphoto.settings import settings


def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) -> None:
def test_user_auth_functions(app_client: TestClient) -> None:
asyncio.run(create_tables())

test_username = "testusername"
test_password = "ccccc@#$bhui1324frhnund!!@#$"
assert (test_user := settings.user.test_user) is not None

# Attempts to log in before creating the user.
response = app_client.post("/users/login", json={"username": test_username, "password": test_password})
response = app_client.post("/users/google", json={"token": test_user.google_token})
assert response.status_code == 200, response.json()
assert mock_send_email.call_count == 1
api_key = response.json()["api_key"]

# Checks that without the API key we get a 401 response.
response = app_client.get("/users/me")
Expand All @@ -27,7 +26,7 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType)
# Checks that with the API key we get a 200 response.
response = app_client.get("/users/me", headers={"Authorization": f"Bearer {api_key}"})
assert response.status_code == 200, response.json()
assert response.json()["email"] == test_email
assert response.json()["email"] == test_user.email

# Checks that we can't log the user out without the API key.
response = app_client.delete("/users/logout")
Expand All @@ -44,8 +43,9 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType)
assert response.json()["detail"] == "User not found"

# Log the user back in, getting new API key.
response = app_client.post("/users/otp", json={"payload": otp.encode()})
response = app_client.post("/users/google", json={"token": test_user.google_token})
assert response.status_code == 200, response.json()
api_key = response.json()["api_key"]

# Delete the user using the new API key.
response = app_client.delete("/users/me", headers={"Authorization": f"Bearer {api_key}"})
Expand Down

0 comments on commit 52446ea

Please sign in to comment.