Skip to content

Commit

Permalink
feat: add latest reference and current uuids to modelOut dto (#32)
Browse files Browse the repository at this point in the history
* feat: add latest reference and current uuids to modelOut dto, edit tests

* feat: align with main, add get latest dataset method to get all models api

* feat: add latest reference and current uuids to model definition
  • Loading branch information
dtria91 authored Jun 26, 2024
1 parent 683f1d3 commit 0780b0e
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 20 deletions.
16 changes: 14 additions & 2 deletions app/db/dao/reference_dataset_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ def insert_reference_dataset(
return reference_dataset

def get_reference_dataset_by_model_uuid(
self, uuid: UUID
self, model_uuid: UUID
) -> Optional[ReferenceDataset]:
with self.db.begin_session() as session:
return (
session.query(ReferenceDataset)
.where(ReferenceDataset.model_uuid == uuid)
.where(ReferenceDataset.model_uuid == model_uuid)
.one_or_none()
)

def get_latest_reference_dataset_by_model_uuid(
self, model_uuid: UUID
) -> Optional[ReferenceDataset]:
with self.db.begin_session() as session:
return (
session.query(ReferenceDataset)
.order_by(desc(ReferenceDataset.date))
.where(ReferenceDataset.model_uuid == model_uuid)
.limit(1)
.one_or_none()
)

Expand Down
6 changes: 5 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)

model_service = ModelService(model_dao)
model_service = ModelService(
model_dao=model_dao,
reference_dataset_dao=reference_dataset_dao,
current_dataset_dao=current_dataset_dao,
)
s3_config = get_config().s3_config

if s3_config.s3_endpoint_url is not None:
Expand Down
10 changes: 9 additions & 1 deletion app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,19 @@ class ModelOut(BaseModel):
algorithm: Optional[str]
created_at: str
updated_at: str
latest_reference_uuid: Optional[UUID]
latest_current_uuid: Optional[UUID]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)

@staticmethod
def from_model(model: Model):
def from_model(
model: Model,
latest_reference_uuid: Optional[UUID] = None,
latest_current_uuid: Optional[UUID] = None,
):
return ModelOut(
uuid=model.uuid,
name=model.name,
Expand All @@ -120,4 +126,6 @@ def from_model(model: Model):
algorithm=model.algorithm,
created_at=str(model.created_at),
updated_at=str(model.updated_at),
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
65 changes: 61 additions & 4 deletions app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@

from fastapi_pagination import Page, Params

from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.db.tables.model_table import Model
from app.models.exceptions import ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelIn, ModelOut
from app.models.model_order import OrderType


class ModelService:
def __init__(self, model_dao: ModelDAO):
def __init__(
self,
model_dao: ModelDAO,
reference_dataset_dao: ReferenceDatasetDAO,
current_dataset_dao: CurrentDatasetDAO,
):
self.model_dao = model_dao
self.rd_dao = reference_dataset_dao
self.cd_dao = current_dataset_dao

def create_model(self, model_in: ModelIn) -> ModelOut:
try:
Expand All @@ -26,7 +35,14 @@ def create_model(self, model_in: ModelIn) -> ModelOut:

def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
return ModelOut.from_model(model)
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model_uuid
)
return ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)

def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
Expand All @@ -37,7 +53,18 @@ def get_all_models(
self,
) -> List[ModelOut]:
models = self.model_dao.get_all()
return [ModelOut.from_model(model) for model in models]
model_out_list = []
for model in models:
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model.uuid
)
model_out = ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
model_out_list.append(model_out)
return model_out_list

def get_all_models_paginated(
self,
Expand All @@ -48,7 +75,18 @@ def get_all_models_paginated(
models: Page[Model] = self.model_dao.get_all_paginated(
params=params, order=order, sort=sort
)
_items = [ModelOut.from_model(model) for model in models.items]

_items = []
for model in models.items:
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model.uuid
)
model_out = ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
_items.append(model_out)

return Page.create(items=_items, params=params, total=models.total)

Expand All @@ -57,3 +95,22 @@ def check_and_get_model(self, model_uuid: UUID) -> Model:
if not model:
raise ModelNotFoundError(f'Model {model_uuid} not found')
return model

def get_latest_dataset_uuids(
self, model_uuid: UUID
) -> (Optional[UUID], Optional[UUID]):
latest_reference_dataset = (
self.rd_dao.get_latest_reference_dataset_by_model_uuid(model_uuid)
)
latest_current_dataset = self.cd_dao.get_latest_current_dataset_by_model_uuid(
model_uuid
)

latest_reference_uuid = (
latest_reference_dataset.uuid if latest_reference_dataset else None
)
latest_current_uuid = (
latest_current_dataset.uuid if latest_current_dataset else None
)

return latest_reference_uuid, latest_current_uuid
48 changes: 39 additions & 9 deletions tests/dao/reference_dataset_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ReferenceDatasetDAOTest(DatabaseIntegration):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.f_reference_dataset_dao = ReferenceDatasetDAO(cls.db)
cls.reference_dataset_dao = ReferenceDatasetDAO(cls.db)
cls.model_dao = ModelDAO(cls.db)

def test_insert_reference_dataset_upload_result(self):
Expand All @@ -26,7 +26,7 @@ def test_insert_reference_dataset_upload_result(self):
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted = self.f_reference_dataset_dao.insert_reference_dataset(to_insert)
inserted = self.reference_dataset_dao.insert_reference_dataset(to_insert)
assert inserted == to_insert

def test_get_reference_dataset_by_model_uuid(self):
Expand All @@ -38,14 +38,45 @@ def test_get_reference_dataset_by_model_uuid(self):
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted = self.f_reference_dataset_dao.insert_reference_dataset(to_insert)
retrieved = self.f_reference_dataset_dao.get_reference_dataset_by_model_uuid(
inserted = self.reference_dataset_dao.insert_reference_dataset(to_insert)
retrieved = self.reference_dataset_dao.get_reference_dataset_by_model_uuid(
inserted.model_uuid
)
assert inserted.uuid == retrieved.uuid
assert inserted.model_uuid == retrieved.model_uuid
assert inserted.path == retrieved.path

def test_get_latest_reference_dataset_by_model_uuid(self):
model = self.model_dao.insert(db_mock.get_sample_model())
reference_one = ReferenceDataset(
uuid=uuid4(),
model_uuid=model.uuid,
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)

self.reference_dataset_dao.insert_reference_dataset(reference_one)

reference_two = ReferenceDataset(
uuid=uuid4(),
model_uuid=model.uuid,
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted_two = self.reference_dataset_dao.insert_reference_dataset(
reference_two
)

retrieved = (
self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid(
model.uuid
)
)
assert inserted_two.uuid == retrieved.uuid
assert inserted_two.model_uuid == retrieved.model_uuid
assert inserted_two.path == retrieved.path

def test_get_all_reference_datasets_by_model_uuid_paginated(self):
model = self.model_dao.insert(db_mock.get_sample_model())
reference_upload_1 = ReferenceDataset(
Expand All @@ -66,17 +97,17 @@ def test_get_all_reference_datasets_by_model_uuid_paginated(self):
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)
inserted_1 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_1 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_1
)
inserted_2 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_2 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_2
)
inserted_3 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_3 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_3
)

retrieved = self.f_reference_dataset_dao.get_all_reference_datasets_by_model_uuid_paginated(
retrieved = self.reference_dataset_dao.get_all_reference_datasets_by_model_uuid_paginated(
model.uuid, Params(page=1, size=10)
)

Expand All @@ -93,4 +124,3 @@ def test_get_all_reference_datasets_by_model_uuid_paginated(self):
assert inserted_3.path == retrieved.items[2].path

assert len(retrieved.items) == 3

56 changes: 53 additions & 3 deletions tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from fastapi_pagination import Page, Params
import pytest

from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.models.exceptions import ModelNotFoundError
from app.models.model_dto import ModelOut
from app.models.model_order import OrderType
Expand All @@ -17,8 +19,14 @@ class ModelServiceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_dao: ModelDAO = MagicMock(spec_set=ModelDAO)
cls.model_service = ModelService(cls.model_dao)
cls.mocks = [cls.model_dao]
cls.rd_dao: ReferenceDatasetDAO = MagicMock(spec_set=ReferenceDatasetDAO)
cls.cd_dao: CurrentDatasetDAO = MagicMock(spec_set=CurrentDatasetDAO)
cls.model_service = ModelService(
model_dao=cls.model_dao,
reference_dataset_dao=cls.rd_dao,
current_dataset_dao=cls.cd_dao,
)
cls.mocks = [cls.model_dao, cls.rd_dao, cls.cd_dao]

def test_create_model_ok(self):
model = db_mock.get_sample_model()
Expand All @@ -31,11 +39,25 @@ def test_create_model_ok(self):

def test_get_model_by_uuid_ok(self):
model = db_mock.get_sample_model()
reference_dataset = db_mock.get_sample_reference_dataset(model_uuid=model.uuid)
current_dataset = db_mock.get_sample_current_dataset(model_uuid=model.uuid)
self.model_dao.get_by_uuid = MagicMock(return_value=model)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=reference_dataset
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=current_dataset
)
res = self.model_service.get_model_by_uuid(model_uuid)
self.model_dao.get_by_uuid.assert_called_once()
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once()
self.cd_dao.get_latest_current_dataset_by_model_uuid.assert_called_once()

assert res == ModelOut.from_model(model)
assert res == ModelOut.from_model(
model=model,
latest_reference_uuid=reference_dataset.uuid,
latest_current_uuid=current_dataset.uuid,
)

def test_get_model_by_uuid_not_found(self):
self.model_dao.get_by_uuid = MagicMock(return_value=None)
Expand Down Expand Up @@ -66,6 +88,12 @@ def test_get_all_models_paginated_ok(self):
sort=None,
)
self.model_dao.get_all_paginated = MagicMock(return_value=page)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=None
)

result = self.model_service.get_all_models_paginated(
params=Params(page=1, size=10), order=OrderType.ASC, sort=None
Expand All @@ -81,5 +109,27 @@ def test_get_all_models_paginated_ok(self):
assert result.items[1].name == 'model2'
assert result.items[2].name == 'model3'

def test_get_all_models_ok(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='model1')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='model2')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3')
sample_models = [model1, model2, model3]
self.model_dao.get_all = MagicMock(return_value=sample_models)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=None
)

result = self.model_service.get_all_models()

self.model_dao.get_all.assert_called_once()

assert len(result) == 3
assert result[0].name == 'model1'
assert result[1].name == 'model2'
assert result[2].name == 'model3'


model_uuid = db_mock.MODEL_UUID

0 comments on commit 0780b0e

Please sign in to comment.