-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
57 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,25 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Generic, Protocol, TypeVar, Optional | ||
from typing import Any, Optional | ||
from uuid import UUID | ||
|
||
|
||
class HasID(Protocol): | ||
id: UUID # Enforce that any entity passed to the repository must have an `id` attribute | ||
|
||
|
||
T = TypeVar("T", bound=HasID) | ||
|
||
|
||
class Database(ABC, Generic[T]): | ||
class Database(ABC): | ||
@abstractmethod | ||
def add(self, entity: T) -> None: | ||
def add(self, entity: dict) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def get(self, id: UUID) -> Optional[T]: | ||
def get(self, id: UUID) -> Optional[dict[str, Any]]: | ||
pass | ||
|
||
@abstractmethod | ||
def update(self, entity: T) -> None: | ||
def update(self, entity: dict[str, Any]) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def delete(self, id: UUID) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def list_all(self) -> list[T]: | ||
def list_all(self) -> list[dict[str, Any]]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,27 @@ | ||
from typing import Optional | ||
from typing import Any, Optional | ||
from uuid import UUID | ||
|
||
from application.interfaces.database import T, Database | ||
from application.interfaces.database import Database | ||
|
||
|
||
class InMemoryDatabase(Database[T]): | ||
def __init__(self) -> None: | ||
self._data: dict[UUID, T] = {} | ||
class InMemoryDatabase(Database): | ||
def __init__(self, model_name: str) -> None: | ||
self.model_name = model_name | ||
self.storage = {} | ||
|
||
def add(self, entity: T) -> None: | ||
self._data[entity.id] = entity | ||
def add(self, entity_data: dict[str, Any]) -> None: | ||
self.storage[entity_data["id"]] = entity_data | ||
|
||
def get(self, id: UUID) -> Optional[T]: | ||
return self._data.get(id) | ||
def get(self, id: UUID) -> Optional[dict]: | ||
return self.storage.get(str(id)) | ||
|
||
def update(self, entity: T) -> None: | ||
if entity.id in self._data: | ||
self._data[entity.id] = entity | ||
def update(self, entity_data: dict[str, Any]) -> None: | ||
if str(entity_data["id"]) in self.storage: | ||
self.storage[str(entity_data["id"])] = entity_data | ||
|
||
def delete(self, id: UUID) -> None: | ||
if id in self._data: | ||
del self._data[id] | ||
if str(id) in self.storage: | ||
del self.storage[str(id)] | ||
|
||
def list_all(self) -> list[T]: | ||
return list(self._data.values()) | ||
def list_all(self) -> list[dict[str, Any]]: | ||
return list(self.storage.values()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,39 +5,39 @@ | |
|
||
|
||
@pytest.fixture | ||
def in_memory_db() -> InMemoryDatabase[User]: | ||
return InMemoryDatabase[User]() | ||
def in_memory_db() -> InMemoryDatabase: | ||
return InMemoryDatabase("User") | ||
|
||
|
||
def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> None: | ||
def test_in_memory_database_operations(in_memory_db: InMemoryDatabase) -> None: | ||
# Test data | ||
user1 = User(name="John Doe", email="[email protected]", password="password123", confirmed=True) | ||
user2 = User( | ||
name="Jane Smith", email="[email protected]", password="password456", confirmed=False | ||
) | ||
|
||
# 1. Add user1 | ||
in_memory_db.add(user1) | ||
fetched_user1 = in_memory_db.get(user1.id) | ||
in_memory_db.add(user1.to_dict()) | ||
fetched_user1 = User.from_dict(in_memory_db.get(user1.id)) | ||
assert fetched_user1 is not None | ||
assert fetched_user1.name == user1.name | ||
assert fetched_user1.email == user1.email | ||
|
||
# 2. Update user1's name | ||
user1.name = "John Doe Updated" | ||
in_memory_db.update(user1) | ||
updated_user1 = in_memory_db.get(user1.id) | ||
in_memory_db.update(user1.to_dict()) | ||
updated_user1 = User.from_dict(in_memory_db.get(user1.id)) | ||
assert updated_user1 | ||
assert updated_user1.name == "John Doe Updated" | ||
|
||
# 3. Add user2 | ||
in_memory_db.add(user2) | ||
fetched_user2 = in_memory_db.get(user2.id) | ||
in_memory_db.add(user2.to_dict()) | ||
fetched_user2 = User.from_dict(in_memory_db.get(user2.id)) | ||
assert fetched_user2 is not None | ||
assert fetched_user2.name == user2.name | ||
|
||
# 4. List all users | ||
all_users = in_memory_db.list_all() | ||
all_users = [User.from_dict(user) for user in in_memory_db.list_all()] | ||
assert len(all_users) == 2 | ||
assert user1 in all_users | ||
assert user2 in all_users | ||
|
@@ -48,7 +48,7 @@ def test_in_memory_database_operations(in_memory_db: InMemoryDatabase[User]) -> | |
assert deleted_user1 is None | ||
|
||
# 6. Check remaining users | ||
remaining_users = in_memory_db.list_all() | ||
remaining_users = [User.from_dict(user) for user in in_memory_db.list_all()] | ||
assert len(remaining_users) == 1 | ||
assert user2 in remaining_users | ||
assert user1 not in remaining_users | ||
|