From 4cdcfafe9230f7bcd502dee7ed05be52d2155d9e Mon Sep 17 00:00:00 2001 From: dmb225 Date: Mon, 7 Oct 2024 17:01:43 +0200 Subject: [PATCH] Refactor database inferface --- .gitignore | 3 ++ src/application/entities/user.py | 20 +++++++++++ src/application/interfaces/database.py | 19 ++++------- src/infrastructure/databases/in_memory.py | 33 ++++++++++--------- .../infrastrucure/databases/test_in_memory.py | 22 ++++++------- 5 files changed, 57 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index 0590f8f..c534511 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,8 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# Vs code +.vscode/ + # Diff file *.diff diff --git a/src/application/entities/user.py b/src/application/entities/user.py index b832232..1484e6c 100644 --- a/src/application/entities/user.py +++ b/src/application/entities/user.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any, Self from uuid import UUID, uuid4 @@ -9,3 +10,22 @@ class User: password: str confirmed: bool id: UUID = field(default_factory=uuid4) + + def to_dict(self) -> dict[str, Any]: + return { + "id": str(self.id), + "name": self.name, + "email": self.email, + "password": self.password, + "confirmed": self.confirmed, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: + return cls( + id=UUID(data["id"]), + name=data["name"], + email=data["email"], + password=data["password"], + confirmed=data["confirmed"], + ) diff --git a/src/application/interfaces/database.py b/src/application/interfaces/database.py index df8f7e8..e914d00 100644 --- a/src/application/interfaces/database.py +++ b/src/application/interfaces/database.py @@ -1,26 +1,19 @@ from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar, Optional +from typing import Any, Optional from uuid import UUID -class HasID(Protocol): - id: UUID # Enforce that any entity passed to the repository must have an `id` attribute - - -T = TypeVar("T", bound=HasID) - - -class Database(ABC, Generic[T]): +class Database(ABC): @abstractmethod - def add(self, entity: T) -> None: + def add(self, entity: dict) -> None: pass @abstractmethod - def get(self, id: UUID) -> Optional[T]: + def get(self, id: UUID) -> Optional[dict[str, Any]]: pass @abstractmethod - def update(self, entity: T) -> None: + def update(self, entity: dict[str, Any]) -> None: pass @abstractmethod @@ -28,5 +21,5 @@ def delete(self, id: UUID) -> None: pass @abstractmethod - def list_all(self) -> list[T]: + def list_all(self) -> list[dict[str, Any]]: pass diff --git a/src/infrastructure/databases/in_memory.py b/src/infrastructure/databases/in_memory.py index 823baa7..8fb37a4 100644 --- a/src/infrastructure/databases/in_memory.py +++ b/src/infrastructure/databases/in_memory.py @@ -1,26 +1,27 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID -from application.interfaces.database import T, Database +from application.interfaces.database import Database -class InMemoryDatabase(Database[T]): - def __init__(self) -> None: - self._data: dict[UUID, T] = {} +class InMemoryDatabase(Database): + def __init__(self, model_name: str) -> None: + self.model_name = model_name + self.storage = {} - def add(self, entity: T) -> None: - self._data[entity.id] = entity + def add(self, entity_data: dict[str, Any]) -> None: + self.storage[entity_data["id"]] = entity_data - def get(self, id: UUID) -> Optional[T]: - return self._data.get(id) + def get(self, id: UUID) -> Optional[dict]: + return self.storage.get(str(id)) - def update(self, entity: T) -> None: - if entity.id in self._data: - self._data[entity.id] = entity + def update(self, entity_data: dict[str, Any]) -> None: + if str(entity_data["id"]) in self.storage: + self.storage[str(entity_data["id"])] = entity_data def delete(self, id: UUID) -> None: - if id in self._data: - del self._data[id] + if str(id) in self.storage: + del self.storage[str(id)] - def list_all(self) -> list[T]: - return list(self._data.values()) + def list_all(self) -> list[dict[str, Any]]: + return list(self.storage.values()) diff --git a/src/tests/unit/infrastrucure/databases/test_in_memory.py b/src/tests/unit/infrastrucure/databases/test_in_memory.py index 3e57364..ebe56c7 100644 --- a/src/tests/unit/infrastrucure/databases/test_in_memory.py +++ b/src/tests/unit/infrastrucure/databases/test_in_memory.py @@ -5,11 +5,11 @@ @pytest.fixture -def in_memory_db() -> InMemoryDatabase[User]: - return InMemoryDatabase[User]() +def in_memory_db() -> InMemoryDatabase: + return InMemoryDatabase("User") -def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> None: +def test_in_memory_database_operations(in_memory_db: InMemoryDatabase) -> None: # Test data user1 = User(name="John Doe", email="john@example.com", password="password123", confirmed=True) user2 = User( @@ -17,27 +17,27 @@ def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> ) # 1. Add user1 - in_memory_db.add(user1) - fetched_user1 = in_memory_db.get(user1.id) + in_memory_db.add(user1.to_dict()) + fetched_user1 = User.from_dict(in_memory_db.get(user1.id)) assert fetched_user1 is not None assert fetched_user1.name == user1.name assert fetched_user1.email == user1.email # 2. Update user1's name user1.name = "John Doe Updated" - in_memory_db.update(user1) - updated_user1 = in_memory_db.get(user1.id) + in_memory_db.update(user1.to_dict()) + updated_user1 = User.from_dict(in_memory_db.get(user1.id)) assert updated_user1 assert updated_user1.name == "John Doe Updated" # 3. Add user2 - in_memory_db.add(user2) - fetched_user2 = in_memory_db.get(user2.id) + in_memory_db.add(user2.to_dict()) + fetched_user2 = User.from_dict(in_memory_db.get(user2.id)) assert fetched_user2 is not None assert fetched_user2.name == user2.name # 4. List all users - all_users = in_memory_db.list_all() + all_users = [User.from_dict(user) for user in in_memory_db.list_all()] assert len(all_users) == 2 assert user1 in all_users assert user2 in all_users @@ -48,7 +48,7 @@ def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> assert deleted_user1 is None # 6. Check remaining users - remaining_users = in_memory_db.list_all() + remaining_users = [User.from_dict(user) for user in in_memory_db.list_all()] assert len(remaining_users) == 1 assert user2 in remaining_users assert user1 not in remaining_users