diff --git a/.gitignore b/.gitignore index 08132d3..f847706 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ cython_debug/ # Experiments experiments/ + +# MacOS +.DS_Store \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b0c4b53 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.8-slim-buster +EXPOSE 8080 + +ENV PYTHONUNBUFFERED 1 +ENV APP_HOME /app + +COPY requirements.txt . + +RUN pip install --no-cache-dir -U pip \ + && pip install --no-cache-dir -r requirements.txt + +WORKDIR ${APP_HOME} + +COPY src src +COPY main.py . +COPY docker-entrypoint.sh . +COPY .env . + +ENTRYPOINT ["/bin/bash"] +CMD ["./docker-entrypoint.sh"] diff --git a/Pipfile.lock b/Pipfile.lock index 78339db..cd0951b 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -425,4 +425,4 @@ "version": "==1.26.4" } } -} +} \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..49a1fb0 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,32 @@ +version: "3.5" +services: + user_app: + build: . + container_name: cdslab_models_app + env_file: .env + image: cdslab_user + networks: + - cdslab_models + ports: + - 5000:5000 + volumes: + - ./src:/app/src + + user_mongo: + container_name: cdslab_models_mongo + environment: + MONGO_INITDB_ROOT_USERNAME: cdsuser + MONGO_INITDB_ROOT_PASSWORD: cdspass + image: mongo:3-xenial + networks: + - cdslab_models + ports: + - 27017:27017 + volumes: + - /tmp/data/cdslab_models/:/data/db + + +networks: + cdslab_models: + name: cdslab_user + driver: bridge \ No newline at end of file diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100644 index 0000000..b1b4673 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +if [ -z "${PORT}" ] ; then PORT=8080; fi + +if [ -z "${HOST}" ] ; then HOST=0.0.0.0; fi + +uvicorn main:app --port=${PORT} --host=${HOST} --reload diff --git a/main.py b/main.py index 9a4e19b..d93bb78 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,18 @@ +from src.db.mongo import MongoClientSingleton import uvicorn + +from src.api import app from src.config import settings +from src.use_cases.cmodels import CmodelUseCases -if __name__ == '__main__': +__all__ = ['app'] +if __name__ == '__main__': + mongo_singleton = MongoClientSingleton() + CmodelUseCases(mongo_singleton).update_cmodels_collection() uvicorn.run( "src.api:app", - host=settings['HOST'], + host=settings.get('HOST'), port=int(settings['PORT']), reload=True, debug=True, diff --git a/requirements.txt b/requirements.txt index 7fbf992..101272f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ iniconfig==1.1.1 mccabe==0.6.1 mongomock==3.22.1 packaging==20.9; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' +pandas==1.2.4 pluggy==0.13.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' py==1.10.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' pycodestyle==2.7.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' @@ -33,6 +34,7 @@ pyparsing==2.4.7; python_version >= '2.6' and python_version not in '3.0, 3.1, 3 pytest-cov==2.11.1 pytest==6.2.3 python-dotenv==0.17.0 +python-multipart==0.0.5 requests==2.25.1 sentinels==1.0.0 six==1.15.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' diff --git a/src/api.py b/src/api.py index 5585804..a7b63df 100644 --- a/src/api.py +++ b/src/api.py @@ -7,18 +7,17 @@ app = FastAPI() - app.add_middleware( TrustedHostMiddleware, - allowed_hosts=settings["ALLOWED_HOSTS"].split(",") + allowed_hosts=settings.get("ALLOWED_HOSTS", "*").split(",") ) app.add_middleware( CORSMiddleware, allow_credentials=True, - allow_origins=settings["ALLOWED_ORIGINS"].split(","), - allow_methods=settings["ALLOWED_METHODS"].split(","), - allow_headers=settings["ALLOWED_HEADERS"].split(",") + allow_origins=settings.get("ALLOWED_ORIGINS", "*").split(","), + allow_methods=settings.get("ALLOWED_METHODS", "*").split(","), + allow_headers=settings.get("ALLOWED_HEADERS", "*").split(",") ) app.include_router( diff --git a/src/config.py b/src/config.py index 82e0d82..f2eb157 100644 --- a/src/config.py +++ b/src/config.py @@ -1,5 +1,5 @@ from dotenv import dotenv_values -settings = dotenv_values(".env") -db_config = dotenv_values(".db_config") +settings = dotenv_values('.env') +db_config = dotenv_values('.db_config') diff --git a/src/db/mongo.py b/src/db/mongo.py index ebcf18e..f4d81df 100644 --- a/src/db/mongo.py +++ b/src/db/mongo.py @@ -1,47 +1,87 @@ -from typing import Generator, Tuple +from typing import Generator, Optional, Tuple, Union + from pymongo import MongoClient +from pymongo.collection import Collection from pymongo.database import Database +import mongomock from src.config import db_config +from src.utils.patterns import Singleton + + +class MongoClientSingleton(metaclass=Singleton): + db_uri: Optional[str] = db_config.get('MONGO_URI'), + + def __init__( + self, + db_connection: MongoClient = MongoClient(db_uri), + db: Union[str, Database] = db_config.get('MONGO_DB'), + coll: Union[str, Collection] = db_config.get('CMODELS_COLL') + ) -> None: + self.db_connection = db_connection + self.db = db + self.coll = coll + + def api_get_db_connection(self) -> Generator[MongoClient, None, None]: + """Creates connection to Mongodb server. + + Yields + ------ + db_connection: MongoClient + Object containing the db connection. + """ + try: + yield self.db_connection + finally: + self.db_connection.close() + + def get_db( + self, + db: Optional[str] = None + ) -> Tuple[MongoClient, Database]: + """Gets Mongodb connection and database. + + Returns + ------- + db_connection : MongoClient + db : pymongo.database.Database + """ + if db and isinstance(db, str): + self.db = db + return self.db_connection, self.db + + def get_collection( + self, + coll: Optional[str] = None + ) -> Tuple[MongoClient, Collection]: + """ + Returns + ------- + db_connection: pymongo.MongoClient + coll: pymongo.collection.Collection + """ + if coll and isinstance(coll, str): + self.coll = coll + return self.db_connection, self.coll + + @property + def db(self): + return self._db + + @db.setter + def db(self, value): # noqa + if isinstance(value, str): + self._db = self.db_connection[value] + elif isinstance(value, Database) or isinstance(value, mongomock.Database): + self._db = value + @property + def coll(self): + return self._coll -def api_get_db_connection( - db_uri: str = db_config['MONGO_URI'], -) -> Generator[MongoClient, None, None]: - """Creates connection to Mongodb server. - - Yields - ------ - db_connection: MongoClient - Object containing the db connection. - """ - db_connection = MongoClient(db_uri) - - try: - yield db_connection - finally: - db_connection.close() - - -def get_db( - db_uri: str = db_config['MONGO_URI'], - db_name: str = db_config['MONGO_DB'] -) -> Tuple[MongoClient, Database]: - """Gets Mongodb connection and database. - - Parameters - ---------- - db_uri: str - Mongo server URI. - db_name: str - Mongo database name. - - Returns - ------- - db_connection : MongoClient - db : pymongo.database.Database - """ - db_connection = MongoClient(db_uri) - db: Database = db_connection[db_name] - - return db_connection, db + @coll.setter + def coll(self, value): # noqa + if isinstance(value, str): + self._coll = self.db[value] + elif isinstance(value, Collection) or isinstance(value, mongomock.Collection): + self._coll = value diff --git a/src/interfaces/cmodel_interface.py b/src/interfaces/cmodel_interface.py deleted file mode 100644 index ce0deb7..0000000 --- a/src/interfaces/cmodel_interface.py +++ /dev/null @@ -1,5 +0,0 @@ -from src.db import get_db_connection - - -def model_db(): - return get_db_connection().get_collection('cmodels') diff --git a/src/interfaces/cmodels.py b/src/interfaces/cmodels.py new file mode 100644 index 0000000..1f561a7 --- /dev/null +++ b/src/interfaces/cmodels.py @@ -0,0 +1,60 @@ +from datetime import datetime + +from src.db.mongo import MongoClientSingleton +from src.models.db.cmodels import ( + CompartmentalModel, + CompartmentalModelEnum +) +from src.interfaces.crud import MongoCRUD + + +class CModelsInterface: + + def __init__( + self, + mongo_singleton: MongoClientSingleton + ) -> None: + self.mongo_singleton = mongo_singleton + self.crud = MongoCRUD(self.mongo_singleton) + + def insert_one_cmodel_document(self, model: CompartmentalModel): + + cmodel_document = model.dict(by_alias=True) + + existent_model = self.crud.read(model.id) + + if existent_model: + pruned_existent_model = CModelsInterface._prune_db_document( + existent_model + ) + pruned_current_model = CModelsInterface._prune_db_document( + cmodel_document + ) + if pruned_existent_model == pruned_current_model: + # TODO: log cmodel exists f'Cmodel exists: {model.name}' + return existent_model + else: + model.updated_at = datetime.utcnow() + updated_model = self.crud.update( + model.id, + model.dict(by_alias=True) + ) + # TODO log updated cmodel f'Updated cmodel: {model.name}' + return updated_model + else: + model_inserted = self.crud.insert(cmodel_document) + # TODO: log created cmodel f'Created cmodel: {model.name}' + return model_inserted + + def insert_all_cmodel_documents(self): + return [ + self.insert_one_cmodel_document(model) + for model in CompartmentalModelEnum.values() + ] + + @staticmethod + def _prune_db_document(model_in_db: dict) -> dict: + if model_in_db: + model_in_db.pop('inserted_at') + model_in_db.pop('updated_at') + return model_in_db diff --git a/src/interfaces/crud.py b/src/interfaces/crud.py new file mode 100644 index 0000000..91522bd --- /dev/null +++ b/src/interfaces/crud.py @@ -0,0 +1,100 @@ +from typing import Any, Union +from bson.objectid import ObjectId + +from src.db.mongo import MongoClientSingleton + + +class MongoCRUD: + def __init__( + self, + mongo_singleton: MongoClientSingleton + ) -> None: + self.mongo_singleton = mongo_singleton + self.db_connection, self.collection = self.mongo_singleton.get_collection() + + def insert(self, document: dict): + """ + Parameters + ---------- + model + The compartmental models that will be save in the database + + Return + ---------- + model: pymongo object + """ + with self.db_connection: + try: + document['_id'] + except KeyError: + raise ValueError('self.insert(): document must have key "_id"') + + existent_document = self.collection.find_one( + self._id_to_dict(document['_id']) + ) + + if existent_document: + raise ValueError( + f'Document with _id={document["_id"]} already exists in' + 'collection {self.collection}. if you want to change the' + 'fields\' values please use self.update' + ) + # TODO: log + + return self.collection.insert_one(document) + + def read(self, _id: ObjectId) -> Union[Any, None]: + """Search for a specific model in ``self.collection``. + + Parameters + ---------- + query + Document's id. ``dict`` schema: ``{'_id': bson.ObjectID}`` + + Return + ---------- + model: pymongo object + Object containing the results of the search + """ + with self.db_connection: + return self.collection.find_one(self._id_to_dict(_id)) + + def update(self, _id: ObjectId, new_data: dict) -> bool: + """Update document in ``self.collection``. + + Parameters + ---------- + query + Document's id. ``dict`` schema: ``{'_id': bson.ObjectID}`` + data + Updated document fields + + Return + ---------- + False: + * If query has no information + * If is not possible to update the document's status + * If the query id doesn't match the one associated to ``new_data`` + True: + If the model has valid data and its status can be updated + """ + _id_dict = self._id_to_dict(_id) + with self.db_connection: + if not _id_dict: + return False + cmodel = self.collection.find_one(_id_dict) + if cmodel: + update_model = self.collection.update_one( + _id_dict, + {"$set": new_data}, + ) + if update_model: + return True + return False + + def delete(self, _id: ObjectId): + with self.db_connection: + return self.collection.delete_one(self._id_to_dict(_id)) + + def _id_to_dict(self, _id: ObjectId): + return {'_id': _id} if _id else None diff --git a/src/interfaces/simulation_interface.py b/src/interfaces/simulation_interface.py deleted file mode 100644 index 604a999..0000000 --- a/src/interfaces/simulation_interface.py +++ /dev/null @@ -1,5 +0,0 @@ -from src.db import get_db_connection - - -def simulation_db(): - return get_db_connection().get_collection('cmodels_simulation') diff --git a/src/models/db/__init__.py b/src/models/db/__init__.py index e69de29..04752da 100644 --- a/src/models/db/__init__.py +++ b/src/models/db/__init__.py @@ -0,0 +1,11 @@ +from .cmodels import ( + CompartmentalModelBase, + CompartmentalModel, + CModel, +) + +__all__ = [ + 'CompartmentalModelBase', + 'CompartmentalModel', + 'CModel', +] diff --git a/src/models/db/base_model.py b/src/models/db/base_model.py new file mode 100644 index 0000000..202c3f6 --- /dev/null +++ b/src/models/db/base_model.py @@ -0,0 +1,26 @@ +from datetime import datetime + +from pydantic import BaseModel, Field +from bson.objectid import ObjectId as BsonObjectId + + +class PydanticObjectId(BsonObjectId): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + if not isinstance(v, BsonObjectId): + raise TypeError('ObjectId required') + return v + + +class MetadataBaseDoc(BaseModel): + id: PydanticObjectId = Field(..., alias='_id') + inserted_at: datetime = datetime.utcnow() + updated_at: datetime = datetime.utcnow() + + class Config: + allow_population_by_field_name = True + extra = 'forbid' diff --git a/src/models/db/cmodels.py b/src/models/db/cmodels.py new file mode 100644 index 0000000..cc243cf --- /dev/null +++ b/src/models/db/cmodels.py @@ -0,0 +1,96 @@ +from enum import Enum +from typing import Dict, List + +from bson.objectid import ObjectId +from pydantic import BaseModel + +from .base_model import MetadataBaseDoc + + +class CompartmentalModelBase(BaseModel): + """Base Model for Compartmental Models + """ + name: str + """Name of the compartmental model""" + state_variables: List[str] + """Name of state variables of corresponding model""" + state_variables_units: Dict[str, str] + """Units of each state variable. The keys are the 'state_variables' array + elements + """ + parameters: List[str] + """Parameters of the corresponding model""" + parameters_units: Dict[str, str] + """Units of each parameter. The keys are the 'parameters' array elements""" + + +class CompartmentalModel(MetadataBaseDoc, CompartmentalModelBase): + pass + + +class CompartmentalModelEnum(Enum): + """Compartmental Models' Data. + + Each element of this ``Enum`` class contains all the essential information + on the corresponding Compartmental Model. Each element is a + :class:``CompartmentalModelBase`` object + """ + sir: CompartmentalModelBase = CompartmentalModel( + id=ObjectId('6083175ea91f5aacea234423'), + name='SIR', + state_variables=['S', 'I', 'R'], + state_variables_units={ + 'S': 'persons', + 'I': 'persons', + 'R': 'persons', + }, + parameters=['a', 'b'], + parameters_units={ + 'a': 'units of a', + 'b': 'units of b', + }, + ) + + seir: CompartmentalModelBase = CompartmentalModel( + id=ObjectId('6083176ca91f5aacea234424'), + name='SEIR', + state_variables=['S', 'E', 'I', 'R'], + state_variables_units={ + 'S': 'persons', + 'E': 'persons', + 'I': 'persons', + 'R': 'persons', + }, + parameters=['a', 'b'], + parameters_units={ + 'a': 'units of a', + 'b': 'units of b', + }, + ) + + seirv: CompartmentalModelBase = CompartmentalModel( + id=ObjectId('608317d0a91f5aacea234426'), + name='SEIRV', + state_variables=['S', 'E', 'I', 'R', 'V'], + state_variables_units={ + 'S': 'persons', + 'E': 'persons', + 'I': 'persons', + 'R': 'persons', + 'V': 'persons', + }, + parameters=['a', 'b'], + parameters_units={ + 'a': 'units of a', + 'b': 'units of b', + 'c': 'units of c' + }, + ) + + @classmethod + def values(cls) -> List[CompartmentalModel]: + return [m.value for m in cls] + + +class CModel(BaseModel): + model: CompartmentalModelEnum diff --git a/src/models/routers/simulation.py b/src/models/db/simulations.py similarity index 53% rename from src/models/routers/simulation.py rename to src/models/db/simulations.py index 2d87447..8a681b7 100644 --- a/src/models/routers/simulation.py +++ b/src/models/db/simulations.py @@ -1,20 +1,8 @@ from enum import Enum from datetime import datetime -from bson.objectid import ObjectId as BsonObjectId from pydantic import BaseModel - - -class PydanticObjectId(BsonObjectId): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - if not isinstance(v, BsonObjectId): - raise TypeError('ObjectId required') - return v +from .base_model import PydanticObjectId class SimulationType(str, Enum): diff --git a/src/models/routers/__init__.py b/src/models/routers/__init__.py index 0f2ba5f..e69de29 100644 --- a/src/models/routers/__init__.py +++ b/src/models/routers/__init__.py @@ -1,20 +0,0 @@ -from .cmodel import ( - CompartmentalModelBase, - CompartmentalModel, - CModel, - AllCModels -) -from .simulation import ( - SimulationType, - SimulationConfig -) - - -__all__ = [ - 'CompartmentalModelBase', - 'CompartmentalModel', - 'CModel', - 'AllCModels', - 'SimulationType', - 'SimulationConfig', -] diff --git a/src/models/routers/cmodel.py b/src/models/routers/cmodel.py index 87151dc..e69de29 100644 --- a/src/models/routers/cmodel.py +++ b/src/models/routers/cmodel.py @@ -1,87 +0,0 @@ -from enum import Enum -from typing import Dict, List - -from pydantic import BaseModel - - -class CompartmentalModelBase(BaseModel): - """Base Model for Compartmental Models - """ - name: str - """Name of the compartmental model""" - state_variables: List[str] - """Name of state variables of corresponding model""" - state_variables_units: Dict[str, str] - """Units of each state variable. The keys are the 'state_variables' array - elements - """ - parameters: List[str] - """Parameters of the corresponding model""" - parameters_units: Dict[str, str] - """Units of each parameter. The keys are the 'parameters' array elements""" - - -class CompartmentalModel(Enum): - """Compartmental Models' Data. - - Each element of this ``Enum`` class contains all the essential information - on the corresponding Compartmental Model. Each element is a - :class:``CompartmentalModelBase`` object - """ - sir = CompartmentalModelBase( - name='SIR', - state_variables=['S', 'I', 'R'], - state_variables_units={ - 'S': 'persons', - 'I': 'persons', - 'R': 'persons', - }, - parameters=['a', 'b'], - parameters_units={ - 'a': 'units of a', - 'b': 'units of b', - }, - ) - - seir = CompartmentalModelBase( - name='SEIR', - state_variables=['S', 'E', 'I', 'R'], - state_variables_units={ - 'S': 'persons', - 'E': 'persons', - 'I': 'persons', - 'R': 'persons', - }, - parameters=['a', 'b'], - parameters_units={ - 'a': 'units of a', - 'b': 'units of b', - }, - ) - - seirv = CompartmentalModelBase( - name='SEIRV', - state_variables=['S', 'E', 'I', 'R', 'V'], - state_variables_units={ - 'S': 'persons', - 'E': 'persons', - 'I': 'persons', - 'R': 'persons', - 'V': 'persons', - }, - parameters=['a', 'b'], - parameters_units={ - 'a': 'units of a', - 'b': 'units of b', - }, - ) - - -class CModel(BaseModel): - model: CompartmentalModel - - -class AllCModels(BaseModel): - models: List[CompartmentalModel] = [ - model.value for model in CompartmentalModel - ] diff --git a/src/routers/main.py b/src/routers/main.py index ad0f8ce..afe7844 100644 --- a/src/routers/main.py +++ b/src/routers/main.py @@ -3,6 +3,7 @@ main_router_prefix = "" main_router = APIRouter(prefix=main_router_prefix) + @main_router.get('/') async def hello(request: Request): return {'hello': 'world'} diff --git a/src/use_cases/cmodels.py b/src/use_cases/cmodels.py new file mode 100644 index 0000000..4661c68 --- /dev/null +++ b/src/use_cases/cmodels.py @@ -0,0 +1,15 @@ +from src.db.mongo import MongoClientSingleton +from src.interfaces.cmodels import CModelsInterface + + +class CmodelUseCases: + + def __init__( + self, + mongo_singleton: MongoClientSingleton + ) -> None: + self.mongo_singleton = mongo_singleton + self.cmodels_interface = CModelsInterface(self.mongo_singleton) + + def update_cmodels_collection(self) -> None: + self.cmodels_interface.insert_all_cmodel_documents() diff --git a/src/utils/date_time.py b/src/utils/date_time.py new file mode 100644 index 0000000..4184541 --- /dev/null +++ b/src/utils/date_time.py @@ -0,0 +1,8 @@ +from datetime import datetime + + +class DateTime: + + @classmethod + def current_datetime(cls) -> datetime: + return datetime.utcnow() diff --git a/src/utils/patterns.py b/src/utils/patterns.py new file mode 100644 index 0000000..3776cb9 --- /dev/null +++ b/src/utils/patterns.py @@ -0,0 +1,7 @@ +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/tests/db/test_mongo.py b/tests/db/test_mongo.py index f371b0b..7d7caae 100644 --- a/tests/db/test_mongo.py +++ b/tests/db/test_mongo.py @@ -1,11 +1,10 @@ from unittest import TestCase from unittest.mock import patch, Mock -from mongomock import patch as db_patch -from pymongo.database import Database -from pymongo import MongoClient +import mongomock +from mongomock import patch as db_path -from src.db.mongo import api_get_db_connection, get_db +from src.db.mongo import MongoClientSingleton def solve_path(path: str): @@ -14,29 +13,45 @@ def solve_path(path: str): class MongoTestCase(TestCase): - server = "mongodb://mongodb0.example.com:27017" - @db_patch(servers=(('mongodb://mongodb.example.com', 27017),)) - @patch(solve_path('db_config')) - def test_get_db_connection(self, mock_config: Mock): - mock_config.get.side_effect = [MongoTestCase.server, "test_db"] + @db_path(servers=(('server.example.com', 27017),)) + def setUp(self): + self.connection_mock = mongomock.MongoClient('server.example.com') + self.db_mock = self.connection_mock.db + self.collection_mock = self.db_mock.collection + self.mongo_singleton_mock = MongoClientSingleton( + self.connection_mock, self.db_mock, self.collection_mock + ) + + def tearDown(self): + self.collection_mock.drop() + self.connection_mock.close() - db_connection_generator = api_get_db_connection(MongoTestCase.server) - db_connection = next(db_connection_generator) + @patch(solve_path('db_config')) + def test_get_db_connection(self, mock: Mock): + db_connection_gen = self.mongo_singleton_mock.api_get_db_connection() + db_connection = next(db_connection_gen) - self.assertIsInstance(db_connection, MongoClient) + self.assertIsInstance(db_connection, mongomock.MongoClient) - try: - next(db_connection_generator) - except StopIteration: - pass + with self.assertRaises(StopIteration): + next(db_connection_gen) - @db_patch(servers=(('mongodb://mongodb.example.com', 27017),)) @patch(solve_path('db_config')) def test_get_db(self, mock_config: Mock): - mock_config.get.side_effect = [MongoTestCase.server, "test_db"] + db_connection, db = self.mongo_singleton_mock.get_db( + "test_db" + ) + + self.assertIsInstance(db_connection, mongomock.MongoClient) + self.assertIsInstance(db, mongomock.Database) + + @patch(solve_path('db_config')) + def test_get_collection(self, mock_config: Mock): - db_connection, db = get_db(MongoTestCase.server, "test_db") + db_connection, coll = self.mongo_singleton_mock.get_collection( + 'test_coll' + ) - self.assertIsInstance(db_connection, MongoClient) - self.assertIsInstance(db, Database) + self.assertIsInstance(db_connection, mongomock.MongoClient) + self.assertIsInstance(coll, mongomock.Collection) diff --git a/tests/interfaces/__init__.py b/tests/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/interfaces/test_cmodels.py b/tests/interfaces/test_cmodels.py new file mode 100644 index 0000000..99d74f7 --- /dev/null +++ b/tests/interfaces/test_cmodels.py @@ -0,0 +1,119 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +import mongomock +from mongomock import patch as db_path +from pymongo.results import InsertOneResult + +from src.interfaces.cmodels import CModelsInterface +from src.models.db.cmodels import CompartmentalModelEnum +from src.db.mongo import MongoClientSingleton + + +def solve_path(path: str): + source = 'src.config' + return ".".join([source, path]) + + +class CModelsInterfaceTestCase(TestCase): + + @db_path(servers=(('server.example.com', 27017),)) + def setUp(self): + self.connection_mock = mongomock.MongoClient('server.example.com') + self.db_mock = self.connection_mock.db + self.collection_mock = self.db_mock.collection + self.mongo_singleton_mock = MongoClientSingleton( + db_connection=self.connection_mock, + db=self.db_mock, + coll=self.collection_mock + ) + self.cmodels_interface_mock = CModelsInterface( + self.mongo_singleton_mock + ) + self.cmodel_example_document = CompartmentalModelEnum.sir.value + + def tearDown(self): + self.mongo_singleton_mock.coll.drop() + self.connection_mock.close() + + @patch(solve_path('db_config')) + def test_insert_one_cmodel_document_ok(self, mock: Mock): + + result = self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, InsertOneResult) + + @patch(solve_path('db_config')) + def test_insert_one_cmodel_document_exists(self, mock: Mock): + + self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + result = self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + pruned_example_document = CModelsInterface._prune_db_document( + self.cmodel_example_document.dict(by_alias=True) + ) + + self.assertEqual( + result, + pruned_example_document + ) + + @patch(solve_path('db_config')) + def test_insert_one_cmodel_document_update(self, mock: Mock): + + self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + self.cmodel_example_document.state_variables = ['S', 'I'] + + result = self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + self.assertIsNotNone(result) + self.assertTrue(result) + + @patch(solve_path('db_config')) + def test_prune_db_document(self, mock: Mock): + + _id = self.cmodel_example_document.id + + self.cmodels_interface_mock.insert_one_cmodel_document( + self.cmodel_example_document + ) + + read_model = self.cmodels_interface_mock.crud.read(_id) + pruned_document = self.cmodels_interface_mock._prune_db_document(read_model) + + self.assertIsNotNone(pruned_document) + + try: + pruned_document['inserted_at'] + except KeyError: + self.assertTrue(True) + else: + self.fail('inserted_at key not expected') + + try: + pruned_document['updated_at'] + except KeyError: + self.assertTrue(True) + else: + self.fail('updated_at key not expected') + + @patch(solve_path('db_config')) + def test_insert_all_models(self, mock: Mock): + + result = self.cmodels_interface_mock.insert_all_cmodel_documents() + + self.assertIsNotNone(result) + self.assertIsInstance(result[0], InsertOneResult) diff --git a/tests/interfaces/test_crud.py b/tests/interfaces/test_crud.py new file mode 100644 index 0000000..7c3368c --- /dev/null +++ b/tests/interfaces/test_crud.py @@ -0,0 +1,108 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +import mongomock +from mongomock import patch as db_path +from pymongo.results import InsertOneResult, DeleteResult + +from src.interfaces.crud import MongoCRUD +from src.models.db.cmodels import CompartmentalModelEnum +from src.db.mongo import MongoClientSingleton + + +def solve_path(path: str): + source = 'src.config' + return ".".join([source, path]) + + +class MongoCRUDTestCase(TestCase): + + @db_path(servers=(('server.example.com', 27017),)) + def setUp(self): + self.connection_mock = mongomock.MongoClient('server.example.com') + self.db_mock = self.connection_mock.db + self.collection_mock = self.db_mock.collection + self.mongo_singleton_mock = MongoClientSingleton( + db_connection=self.connection_mock, + db=self.db_mock, + coll=self.collection_mock + ) + self.mongo_crud_mock = MongoCRUD(self.mongo_singleton_mock) + self._id_example = CompartmentalModelEnum.sir.value.id + self.model_example = {'_id': self._id_example} + + def tearDown(self): + self.mongo_singleton_mock.coll.drop() + self.connection_mock.close() + + @patch(solve_path('db_config')) + def test_insert_ok(self, mock: Mock): + result = self.mongo_crud_mock.insert(self.model_example) + + self.assertIsNotNone(result) + self.assertIsInstance(result, InsertOneResult) + + @patch(solve_path('db_config')) + def test_insert_no_id_present_in_document(self, mock: Mock): + try: + self.mongo_crud_mock.insert( + {'no_id_present': 'not_a_valid_query'} + ) + except ValueError: + self.assertTrue(True) + else: + self.fail('_id must be present in inserted document') + + @patch(solve_path('db_config')) + def test_insert_existent_document(self, mock: Mock): + self.mongo_crud_mock.insert(self.model_example) + try: + self.mongo_crud_mock.insert(self.model_example) + except ValueError: + self.assertTrue(True) + else: + self.fail('_id must be present in inserted document') + + @patch(solve_path('db_config')) + def test_read_ok(self, mock: Mock): + self.mongo_crud_mock.insert(self.model_example) + read_result = self.mongo_crud_mock.read(self._id_example) + + self.assertIsNotNone(read_result) + self.assertIsInstance(read_result, dict) + + @patch(solve_path('db_config')) + def test_read_not_found(self, mock: Mock): + result = self.mongo_crud_mock.read(self._id_example) + self.assertIsNone(result) + + @patch(solve_path('db_config')) + def test_update_cmodel_state_ok(self, mock: Mock): + self.mongo_crud_mock.insert(self.model_example) + new_data = {'params': ['a', 'b', 'c']} + + result = self.mongo_crud_mock.update( + self._id_example, + new_data + ) + + self.assertTrue(result) + + @patch(solve_path('db_config')) + def test_update_state_fail(self, mock: Mock): + result = self.mongo_crud_mock.update(self._id_example, {}) + + self.assertFalse(result) + + @patch(solve_path('db_config')) + def test_update_state_no_query(self, mock: Mock): + result = self.mongo_crud_mock.update(None, {}) + self.assertFalse(result) + + @patch(solve_path('db_config')) + def test_delete_state_ok(self, mock: Mock): + self.mongo_crud_mock.insert(self.model_example) + result = self.mongo_crud_mock.delete(self._id_example) + + self.assertIsNotNone(result) + self.assertIsInstance(result, DeleteResult) diff --git a/tests/models/db/__init__.py b/tests/models/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/db/test_base_model.py b/tests/models/db/test_base_model.py new file mode 100644 index 0000000..9af1743 --- /dev/null +++ b/tests/models/db/test_base_model.py @@ -0,0 +1,13 @@ +from pydantic import ValidationError +import pytest + +from src.models.db.base_model import MetadataBaseDoc + + +def test_SimulationConfig(): + try: + MetadataBaseDoc(id='not a bson.ObjectID') + except ValidationError: + assert True + else: + assert pytest.fail('MetadataBaseDoc must contain id field') diff --git a/tests/models/routers/test_cmodel.py b/tests/models/db/test_cmodels.py similarity index 54% rename from tests/models/routers/test_cmodel.py rename to tests/models/db/test_cmodels.py index 709c319..59ee36e 100644 --- a/tests/models/routers/test_cmodel.py +++ b/tests/models/db/test_cmodels.py @@ -1,10 +1,10 @@ +import pytest from hypothesis import given, strategies as st -from src.models.routers.cmodel import ( +from src.models.db.cmodels import ( CompartmentalModelBase, - CompartmentalModel, - CModel, - AllCModels, + CompartmentalModelEnum, + CModel ) @@ -17,17 +17,11 @@ def test_CompartmentalModelBase_properties(instance: CompartmentalModelBase): assert isinstance(instance.parameters_units, dict) -def test_CompartmentalModel(): - for model in CompartmentalModel: - assert isinstance(model.value, CompartmentalModelBase) +@pytest.mark.parametrize('model', [model for model in CompartmentalModelEnum]) +def test_CompartmentalModel(model: CompartmentalModelEnum): + assert isinstance(model.value, CompartmentalModelBase) @given(st.builds(CModel)) def test_CModel(instance: CModel): - assert isinstance(instance.model, CompartmentalModel) - - -@given(st.builds(AllCModels)) -def test_AllCmodels(instance: AllCModels): - for model in instance.models: - assert isinstance(model, CompartmentalModelBase) + assert isinstance(instance.model, CompartmentalModelEnum) diff --git a/tests/models/routers/test_simulation.py b/tests/models/db/test_simulations.py similarity index 69% rename from tests/models/routers/test_simulation.py rename to tests/models/db/test_simulations.py index 3b2fe0b..77e3577 100644 --- a/tests/models/routers/test_simulation.py +++ b/tests/models/db/test_simulations.py @@ -1,8 +1,10 @@ from datetime import datetime +import pytest +from pydantic import ValidationError from hypothesis import given, strategies as st -from src.models.routers.simulation import ( +from src.models.db.simulations import ( SimulationType, SimulationConfig, PydanticObjectId @@ -22,3 +24,9 @@ def test_SimulationConfig(instance: SimulationConfig): assert isinstance(instance.name, str) assert isinstance(instance.creation_date, datetime) assert isinstance(instance.simulation_type, SimulationType) + + +@given(st.builds(SimulationConfig)) +def test_SimulationConfig_bad_request(instance: SimulationConfig): + with pytest.raises(ValidationError): + SimulationConfig(**{'bad_field_name': 'random value'}) diff --git a/tests/use_cases/__init__.py b/tests/use_cases/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/use_cases/test_cmodels.py b/tests/use_cases/test_cmodels.py new file mode 100644 index 0000000..65701a7 --- /dev/null +++ b/tests/use_cases/test_cmodels.py @@ -0,0 +1,39 @@ +from unittest import TestCase +from unittest.mock import patch, Mock + +import mongomock +from mongomock import patch as db_path + +from src.use_cases.cmodels import CmodelUseCases +from src.interfaces.crud import MongoCRUD +from src.models.db.cmodels import CompartmentalModelEnum +from src.db.mongo import MongoClientSingleton + + +class CmodelUseCasesTestCase(TestCase): + + @db_path(servers=(('server.example.com', 27017),)) + def setUp(self): + self.connection_mock = mongomock.MongoClient('server.example.com') + self.db_mock = self.connection_mock.db + self.collection_mock = self.db_mock.collection + self.mongo_singleton_mock = MongoClientSingleton( + db_connection=self.connection_mock, + db=self.db_mock, + coll=self.collection_mock + ) + self.mongo_crud_mock = MongoCRUD( + self.mongo_singleton_mock + ) + self.cmodel_use_cases = CmodelUseCases(self.mongo_singleton_mock) + + def tearDown(self): + self.collection_mock.drop() + self.connection_mock.close() + + @patch('src.config.db_config') + def test_cmodels_in_ok(self, mock: Mock): + self.cmodel_use_cases.update_cmodels_collection() + for model in CompartmentalModelEnum.values(): + result = self.mongo_crud_mock.read(model.id) + self.assertIsNotNone(result)