diff --git a/linguaphoto/crud/base.py b/linguaphoto/crud/base.py index 14ca71a..f7d6801 100644 --- a/linguaphoto/crud/base.py +++ b/linguaphoto/crud/base.py @@ -52,20 +52,17 @@ def cf(self) -> CloudFrontClient: async def _init_dynamodb(self, session: aioboto3.Session) -> Self: db = session.resource("dynamodb") - await db.__aenter__() - self.__db = db + self.__db = await db.__aenter__() return self async def _init_cloudfront(self, session: aioboto3.Session) -> Self: cf = session.client("cloudfront") - await cf.__aenter__() - self.__cf = cf + self.__cf = await cf.__aenter__() return self async def _init_s3(self, session: aioboto3.Session) -> Self: s3 = session.client("s3") - await s3.__aenter__() - self.__s3 = s3 + self.__s3 = await s3.__aenter__() return self async def _init_redis(self) -> Self: @@ -75,8 +72,7 @@ async def _init_redis(self) -> Self: port=settings.redis.port, db=settings.redis.db, ) - await kv.__aenter__() - self.__kv = kv + self.__kv = await kv.__aenter__() return self async def __aenter__(self) -> Self: @@ -106,7 +102,7 @@ async def _create_dynamodb_table( self, name: str, keys: list[tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]], - gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] = [], + gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] | None = None, deletion_protection: bool = False, ) -> None: """Creates a table in the Dynamo database if a table of that name does not already exist. @@ -123,26 +119,44 @@ async def _create_dynamodb_table( """ try: await self.db.meta.client.describe_table(TableName=name) + logger.info("Found existing table %s", name) + except ClientError: logger.info("Creating %s table", name) - table = await self.db.create_table( - AttributeDefinitions=[ - {"AttributeName": n, "AttributeType": t} - for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis)) - ], - TableName=name, - KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], - GlobalSecondaryIndexes=[ - { - "IndexName": i, - "KeySchema": [{"AttributeName": n, "KeyType": t}], - "Projection": {"ProjectionType": "ALL"}, - } - for i, n, _, t in gsis - ], - DeletionProtectionEnabled=deletion_protection, - BillingMode="PAY_PER_REQUEST", - ) + + if gsis is None: + table = await self.db.create_table( + AttributeDefinitions=[ + {"AttributeName": n, "AttributeType": t} for n, t in ((n, t) for (n, t, _) in keys) + ], + TableName=name, + KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], + DeletionProtectionEnabled=deletion_protection, + BillingMode="PAY_PER_REQUEST", + ) + + else: + table = await self.db.create_table( + AttributeDefinitions=[ + {"AttributeName": n, "AttributeType": t} + for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis)) + ], + TableName=name, + KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], + GlobalSecondaryIndexes=( + [ + { + "IndexName": i, + "KeySchema": [{"AttributeName": n, "KeyType": t}], + "Projection": {"ProjectionType": "ALL"}, + } + for i, n, _, t in gsis + ] + ), + DeletionProtectionEnabled=deletion_protection, + BillingMode="PAY_PER_REQUEST", + ) + await table.wait_until_exists() async def _delete_dynamodb_table(self, name: str) -> None: diff --git a/linguaphoto/db.py b/linguaphoto/db.py index 8399988..52831e7 100644 --- a/linguaphoto/db.py +++ b/linguaphoto/db.py @@ -1,13 +1,16 @@ """Defines base tools for interacting with the database.""" +import argparse import asyncio import logging +import uuid from typing import AsyncGenerator, Self -import argparse from linguaphoto.crud.base import BaseCrud from linguaphoto.crud.images import ImagesCrud from linguaphoto.crud.users import UserCrud +from linguaphoto.model import User +from linguaphoto.settings import settings class Crud( @@ -45,7 +48,6 @@ async def create_tables(crud: Crud | None = None, deletion_protection: bool = Fa ], gsis=[ ("emailIndex", "email", "S", "HASH"), - ("usernameIndex", "username", "S", "HASH"), ], deletion_protection=deletion_protection, ), @@ -54,10 +56,50 @@ async def create_tables(crud: Crud | None = None, deletion_protection: bool = Fa keys=[ ("image_id", "S", "HASH"), ], + gsis=[ + ("userIndex", "user_id", "S", "HASH"), + ], + deletion_protection=deletion_protection, ), ) +async def delete_tables(crud: Crud | None = None) -> None: + """Deletes all of the database tables. + + Args: + crud: The top-level CRUD class. + """ + logging.basicConfig(level=logging.INFO) + + if crud is None: + async with Crud() as crud: + await delete_tables(crud) + + else: + await asyncio.gather( + crud._delete_dynamodb_table("Users"), + crud._delete_dynamodb_table("Images"), + ) + + +async def populate_with_dummy_data(crud: Crud | None = None) -> None: + """Populates the database with dummy data. + + Args: + crud: The top-level CRUD class. + """ + logging.basicConfig(level=logging.INFO) + + if crud is None: + async with Crud() as crud: + await populate_with_dummy_data(crud) + + else: + assert (test_user := settings.user.test_user) is not None + await crud.add_user(user=User(user_id=str(uuid.uuid4()), email=test_user.email)) + + async def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("action", choices=["create", "delete", "populate"]) diff --git a/linguaphoto/settings/configs/local.yaml b/linguaphoto/settings/configs/local.yaml index 2102846..9f65b57 100644 --- a/linguaphoto/settings/configs/local.yaml +++ b/linguaphoto/settings/configs/local.yaml @@ -2,6 +2,7 @@ crypto: jwt_secret: fakeJwtSecret site: homepage: http://127.0.0.1:3000 +aws: image_bucket_id: linguaphoto-images user: test_user: diff --git a/pyproject.toml b/pyproject.toml index 1f4940b..2980288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ include = '\.pyi?$' [tool.pytest.ini_options] -addopts = "-rx -rf -x -q --full-trace" testpaths = ["tests"] markers = [ diff --git a/tests/test_users.py b/tests/test_users.py index 9ead2f4..85963f1 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -3,21 +3,20 @@ import asyncio from fastapi.testclient import TestClient -from pytest_mock.plugin import MockType from linguaphoto.db import create_tables +from linguaphoto.settings import settings -def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) -> None: +def test_user_auth_functions(app_client: TestClient) -> None: asyncio.run(create_tables()) - test_username = "testusername" - test_password = "ccccc@#$bhui1324frhnund!!@#$" + assert (test_user := settings.user.test_user) is not None # Attempts to log in before creating the user. - response = app_client.post("/users/login", json={"username": test_username, "password": test_password}) + response = app_client.post("/users/google", json={"token": test_user.google_token}) assert response.status_code == 200, response.json() - assert mock_send_email.call_count == 1 + api_key = response.json()["api_key"] # Checks that without the API key we get a 401 response. response = app_client.get("/users/me") @@ -27,7 +26,7 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) # Checks that with the API key we get a 200 response. response = app_client.get("/users/me", headers={"Authorization": f"Bearer {api_key}"}) assert response.status_code == 200, response.json() - assert response.json()["email"] == test_email + assert response.json()["email"] == test_user.email # Checks that we can't log the user out without the API key. response = app_client.delete("/users/logout") @@ -44,8 +43,9 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType) assert response.json()["detail"] == "User not found" # Log the user back in, getting new API key. - response = app_client.post("/users/otp", json={"payload": otp.encode()}) + response = app_client.post("/users/google", json={"token": test_user.google_token}) assert response.status_code == 200, response.json() + api_key = response.json()["api_key"] # Delete the user using the new API key. response = app_client.delete("/users/me", headers={"Authorization": f"Bearer {api_key}"})