Skip to content

Commit

Permalink
Merge pull request #35 from CCRI-POPROX/karl/refactor/repository-regi…
Browse files Browse the repository at this point in the history
…stration

Auto-register `DatabaseRepository` sub-classes to create with decorator
  • Loading branch information
karlhigley authored Aug 20, 2024
2 parents bf94e40 + 329d3b8 commit ce84f35
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
32 changes: 31 additions & 1 deletion src/poprox_storage/repositories/data_stores/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Any
from functools import wraps
from typing import Any, get_type_hints
from uuid import UUID

from sqlalchemy import (
Expand All @@ -16,10 +17,39 @@
logger.setLevel(logging.DEBUG)


def inject_repos(handler):
@wraps(handler)
def wrapper(event, context):
with DB_ENGINE.connect() as conn:
params: dict[str, type] = get_type_hints(handler)
# remove event, context, and return type if they were annotated.
params.pop("event", None)
params.pop("context", None)
params.pop("return", None)

repos = dict()
for param, class_obj in params.items():
if class_obj in DatabaseRepository._repository_types:
repos[param] = class_obj(conn)

return handler(event, context, **repos)

return wrapper


class DatabaseRepository:
_repository_types = set()

def __init__(self, connection: Connection):
self.conn: Connection = connection

def __init_subclass__(cls, *args, **kwargs):
"""
Gets called once for each loaded class that sub-classes DatabaseRepository
"""

cls._repository_types.add(cls)

def _load_tables(self, *args) -> dict[str, Table]:
metadata = MetaData()
tables = {}
Expand Down
16 changes: 16 additions & 0 deletions tests/test_repositories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from poprox_storage.repositories import DbArticleRepository
from poprox_storage.repositories.data_stores.db import inject_repos
from poprox_storage.repositories.newsletters import DbNewsletterRepository


@inject_repos
def example(event, context, article: DbArticleRepository, newsletter_repo: DbNewsletterRepository):
return article, newsletter_repo


def test_repositories():
retval = example({}, {})
assert isinstance(retval, tuple)
assert len(retval) == 2
assert isinstance(retval[0], DbArticleRepository)
assert isinstance(retval[1], DbNewsletterRepository)

0 comments on commit ce84f35

Please sign in to comment.