Skip to content

Commit

Permalink
Refactor database inferface
Browse files Browse the repository at this point in the history
  • Loading branch information
dmb225 committed Oct 7, 2024
1 parent 2974a3a commit 4cdcfaf
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Vs code
.vscode/

# Diff file
*.diff
20 changes: 20 additions & 0 deletions src/application/entities/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Self
from uuid import UUID, uuid4


Expand All @@ -9,3 +10,22 @@ class User:
password: str
confirmed: bool
id: UUID = field(default_factory=uuid4)

def to_dict(self) -> dict[str, Any]:
return {
"id": str(self.id),
"name": self.name,
"email": self.email,
"password": self.password,
"confirmed": self.confirmed,
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
return cls(
id=UUID(data["id"]),
name=data["name"],
email=data["email"],
password=data["password"],
confirmed=data["confirmed"],
)
19 changes: 6 additions & 13 deletions src/application/interfaces/database.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar, Optional
from typing import Any, Optional
from uuid import UUID


class HasID(Protocol):
id: UUID # Enforce that any entity passed to the repository must have an `id` attribute


T = TypeVar("T", bound=HasID)


class Database(ABC, Generic[T]):
class Database(ABC):
@abstractmethod
def add(self, entity: T) -> None:
def add(self, entity: dict) -> None:
pass

@abstractmethod
def get(self, id: UUID) -> Optional[T]:
def get(self, id: UUID) -> Optional[dict[str, Any]]:
pass

@abstractmethod
def update(self, entity: T) -> None:
def update(self, entity: dict[str, Any]) -> None:
pass

@abstractmethod
def delete(self, id: UUID) -> None:
pass

@abstractmethod
def list_all(self) -> list[T]:
def list_all(self) -> list[dict[str, Any]]:
pass
33 changes: 17 additions & 16 deletions src/infrastructure/databases/in_memory.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from typing import Optional
from typing import Any, Optional
from uuid import UUID

from application.interfaces.database import T, Database
from application.interfaces.database import Database


class InMemoryDatabase(Database[T]):
def __init__(self) -> None:
self._data: dict[UUID, T] = {}
class InMemoryDatabase(Database):
def __init__(self, model_name: str) -> None:
self.model_name = model_name
self.storage = {}

def add(self, entity: T) -> None:
self._data[entity.id] = entity
def add(self, entity_data: dict[str, Any]) -> None:
self.storage[entity_data["id"]] = entity_data

def get(self, id: UUID) -> Optional[T]:
return self._data.get(id)
def get(self, id: UUID) -> Optional[dict]:
return self.storage.get(str(id))

def update(self, entity: T) -> None:
if entity.id in self._data:
self._data[entity.id] = entity
def update(self, entity_data: dict[str, Any]) -> None:
if str(entity_data["id"]) in self.storage:
self.storage[str(entity_data["id"])] = entity_data

def delete(self, id: UUID) -> None:
if id in self._data:
del self._data[id]
if str(id) in self.storage:
del self.storage[str(id)]

def list_all(self) -> list[T]:
return list(self._data.values())
def list_all(self) -> list[dict[str, Any]]:
return list(self.storage.values())
22 changes: 11 additions & 11 deletions src/tests/unit/infrastrucure/databases/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,39 @@


@pytest.fixture
def in_memory_db() -> InMemoryDatabase[User]:
return InMemoryDatabase[User]()
def in_memory_db() -> InMemoryDatabase:
return InMemoryDatabase("User")


def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> None:
def test_in_memory_database_operations(in_memory_db: InMemoryDatabase) -> None:
# Test data
user1 = User(name="John Doe", email="[email protected]", password="password123", confirmed=True)
user2 = User(
name="Jane Smith", email="[email protected]", password="password456", confirmed=False
)

# 1. Add user1
in_memory_db.add(user1)
fetched_user1 = in_memory_db.get(user1.id)
in_memory_db.add(user1.to_dict())
fetched_user1 = User.from_dict(in_memory_db.get(user1.id))
assert fetched_user1 is not None
assert fetched_user1.name == user1.name
assert fetched_user1.email == user1.email

# 2. Update user1's name
user1.name = "John Doe Updated"
in_memory_db.update(user1)
updated_user1 = in_memory_db.get(user1.id)
in_memory_db.update(user1.to_dict())
updated_user1 = User.from_dict(in_memory_db.get(user1.id))
assert updated_user1
assert updated_user1.name == "John Doe Updated"

# 3. Add user2
in_memory_db.add(user2)
fetched_user2 = in_memory_db.get(user2.id)
in_memory_db.add(user2.to_dict())
fetched_user2 = User.from_dict(in_memory_db.get(user2.id))
assert fetched_user2 is not None
assert fetched_user2.name == user2.name

# 4. List all users
all_users = in_memory_db.list_all()
all_users = [User.from_dict(user) for user in in_memory_db.list_all()]
assert len(all_users) == 2
assert user1 in all_users
assert user2 in all_users
Expand All @@ -48,7 +48,7 @@ def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) ->
assert deleted_user1 is None

# 6. Check remaining users
remaining_users = in_memory_db.list_all()
remaining_users = [User.from_dict(user) for user in in_memory_db.list_all()]
assert len(remaining_users) == 1
assert user2 in remaining_users
assert user1 not in remaining_users
Expand Down

0 comments on commit 4cdcfaf

Please sign in to comment.