Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add latest reference and current uuids to modelOut dto #32

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions api/app/db/dao/reference_dataset_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ def insert_reference_dataset(
return reference_dataset

def get_reference_dataset_by_model_uuid(
self, uuid: UUID
self, model_uuid: UUID
) -> Optional[ReferenceDataset]:
with self.db.begin_session() as session:
return (
session.query(ReferenceDataset)
.where(ReferenceDataset.model_uuid == uuid)
.where(ReferenceDataset.model_uuid == model_uuid)
.one_or_none()
)

def get_latest_reference_dataset_by_model_uuid(
self, model_uuid: UUID
) -> Optional[ReferenceDataset]:
with self.db.begin_session() as session:
return (
session.query(ReferenceDataset)
.order_by(desc(ReferenceDataset.date))
.where(ReferenceDataset.model_uuid == model_uuid)
.limit(1)
.one_or_none()
)

Expand Down
6 changes: 5 additions & 1 deletion api/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)

model_service = ModelService(model_dao)
model_service = ModelService(
model_dao=model_dao,
reference_dataset_dao=reference_dataset_dao,
current_dataset_dao=current_dataset_dao,
)
s3_config = get_config().s3_config

if s3_config.s3_endpoint_url is not None:
Expand Down
10 changes: 9 additions & 1 deletion api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,19 @@ class ModelOut(BaseModel):
algorithm: Optional[str]
created_at: str
updated_at: str
latest_reference_uuid: Optional[UUID]
latest_current_uuid: Optional[UUID]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)

@staticmethod
def from_model(model: Model):
def from_model(
model: Model,
latest_reference_uuid: Optional[UUID] = None,
latest_current_uuid: Optional[UUID] = None,
):
return ModelOut(
uuid=model.uuid,
name=model.name,
Expand All @@ -120,4 +126,6 @@ def from_model(model: Model):
algorithm=model.algorithm,
created_at=str(model.created_at),
updated_at=str(model.updated_at),
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
65 changes: 61 additions & 4 deletions api/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@

from fastapi_pagination import Page, Params

from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.db.tables.model_table import Model
from app.models.exceptions import ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelIn, ModelOut
from app.models.model_order import OrderType


class ModelService:
def __init__(self, model_dao: ModelDAO):
def __init__(
self,
model_dao: ModelDAO,
reference_dataset_dao: ReferenceDatasetDAO,
current_dataset_dao: CurrentDatasetDAO,
):
self.model_dao = model_dao
self.rd_dao = reference_dataset_dao
self.cd_dao = current_dataset_dao

def create_model(self, model_in: ModelIn) -> ModelOut:
try:
Expand All @@ -26,7 +35,14 @@ def create_model(self, model_in: ModelIn) -> ModelOut:

def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
return ModelOut.from_model(model)
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model_uuid
)
return ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)

def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
Expand All @@ -37,7 +53,18 @@ def get_all_models(
self,
) -> List[ModelOut]:
models = self.model_dao.get_all()
return [ModelOut.from_model(model) for model in models]
model_out_list = []
for model in models:
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model.uuid
)
model_out = ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
model_out_list.append(model_out)
return model_out_list

def get_all_models_paginated(
self,
Expand All @@ -48,7 +75,18 @@ def get_all_models_paginated(
models: Page[Model] = self.model_dao.get_all_paginated(
params=params, order=order, sort=sort
)
_items = [ModelOut.from_model(model) for model in models.items]

_items = []
for model in models.items:
latest_reference_uuid, latest_current_uuid = self.get_latest_dataset_uuids(
model.uuid
)
model_out = ModelOut.from_model(
model=model,
latest_reference_uuid=latest_reference_uuid,
latest_current_uuid=latest_current_uuid,
)
_items.append(model_out)

return Page.create(items=_items, params=params, total=models.total)

Expand All @@ -57,3 +95,22 @@ def check_and_get_model(self, model_uuid: UUID) -> Model:
if not model:
raise ModelNotFoundError(f'Model {model_uuid} not found')
return model

def get_latest_dataset_uuids(
self, model_uuid: UUID
) -> (Optional[UUID], Optional[UUID]):
latest_reference_dataset = (
self.rd_dao.get_latest_reference_dataset_by_model_uuid(model_uuid)
)
latest_current_dataset = self.cd_dao.get_latest_current_dataset_by_model_uuid(
model_uuid
)

latest_reference_uuid = (
latest_reference_dataset.uuid if latest_reference_dataset else None
)
latest_current_uuid = (
latest_current_dataset.uuid if latest_current_dataset else None
)

return latest_reference_uuid, latest_current_uuid
48 changes: 39 additions & 9 deletions api/tests/dao/reference_dataset_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ReferenceDatasetDAOTest(DatabaseIntegration):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.f_reference_dataset_dao = ReferenceDatasetDAO(cls.db)
cls.reference_dataset_dao = ReferenceDatasetDAO(cls.db)
cls.model_dao = ModelDAO(cls.db)

def test_insert_reference_dataset_upload_result(self):
Expand All @@ -26,7 +26,7 @@ def test_insert_reference_dataset_upload_result(self):
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted = self.f_reference_dataset_dao.insert_reference_dataset(to_insert)
inserted = self.reference_dataset_dao.insert_reference_dataset(to_insert)
assert inserted == to_insert

def test_get_reference_dataset_by_model_uuid(self):
Expand All @@ -38,14 +38,45 @@ def test_get_reference_dataset_by_model_uuid(self):
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted = self.f_reference_dataset_dao.insert_reference_dataset(to_insert)
retrieved = self.f_reference_dataset_dao.get_reference_dataset_by_model_uuid(
inserted = self.reference_dataset_dao.insert_reference_dataset(to_insert)
retrieved = self.reference_dataset_dao.get_reference_dataset_by_model_uuid(
inserted.model_uuid
)
assert inserted.uuid == retrieved.uuid
assert inserted.model_uuid == retrieved.model_uuid
assert inserted.path == retrieved.path

def test_get_latest_reference_dataset_by_model_uuid(self):
model = self.model_dao.insert(db_mock.get_sample_model())
reference_one = ReferenceDataset(
uuid=uuid4(),
model_uuid=model.uuid,
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)

self.reference_dataset_dao.insert_reference_dataset(reference_one)

reference_two = ReferenceDataset(
uuid=uuid4(),
model_uuid=model.uuid,
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)

inserted_two = self.reference_dataset_dao.insert_reference_dataset(
reference_two
)

retrieved = (
self.reference_dataset_dao.get_latest_reference_dataset_by_model_uuid(
model.uuid
)
)
assert inserted_two.uuid == retrieved.uuid
assert inserted_two.model_uuid == retrieved.model_uuid
assert inserted_two.path == retrieved.path

def test_get_all_reference_datasets_by_model_uuid_paginated(self):
model = self.model_dao.insert(db_mock.get_sample_model())
reference_upload_1 = ReferenceDataset(
Expand All @@ -66,17 +97,17 @@ def test_get_all_reference_datasets_by_model_uuid_paginated(self):
path='frank_file.csv',
date=datetime.datetime.now(tz=datetime.UTC),
)
inserted_1 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_1 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_1
)
inserted_2 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_2 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_2
)
inserted_3 = self.f_reference_dataset_dao.insert_reference_dataset(
inserted_3 = self.reference_dataset_dao.insert_reference_dataset(
reference_upload_3
)

retrieved = self.f_reference_dataset_dao.get_all_reference_datasets_by_model_uuid_paginated(
retrieved = self.reference_dataset_dao.get_all_reference_datasets_by_model_uuid_paginated(
model.uuid, Params(page=1, size=10)
)

Expand All @@ -93,4 +124,3 @@ def test_get_all_reference_datasets_by_model_uuid_paginated(self):
assert inserted_3.path == retrieved.items[2].path

assert len(retrieved.items) == 3

56 changes: 53 additions & 3 deletions api/tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from fastapi_pagination import Page, Params
import pytest

from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.models.exceptions import ModelNotFoundError
from app.models.model_dto import ModelOut
from app.models.model_order import OrderType
Expand All @@ -17,8 +19,14 @@ class ModelServiceTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_dao: ModelDAO = MagicMock(spec_set=ModelDAO)
cls.model_service = ModelService(cls.model_dao)
cls.mocks = [cls.model_dao]
cls.rd_dao: ReferenceDatasetDAO = MagicMock(spec_set=ReferenceDatasetDAO)
cls.cd_dao: CurrentDatasetDAO = MagicMock(spec_set=CurrentDatasetDAO)
cls.model_service = ModelService(
model_dao=cls.model_dao,
reference_dataset_dao=cls.rd_dao,
current_dataset_dao=cls.cd_dao,
)
cls.mocks = [cls.model_dao, cls.rd_dao, cls.cd_dao]

def test_create_model_ok(self):
model = db_mock.get_sample_model()
Expand All @@ -31,11 +39,25 @@ def test_create_model_ok(self):

def test_get_model_by_uuid_ok(self):
model = db_mock.get_sample_model()
reference_dataset = db_mock.get_sample_reference_dataset(model_uuid=model.uuid)
current_dataset = db_mock.get_sample_current_dataset(model_uuid=model.uuid)
self.model_dao.get_by_uuid = MagicMock(return_value=model)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=reference_dataset
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=current_dataset
)
res = self.model_service.get_model_by_uuid(model_uuid)
self.model_dao.get_by_uuid.assert_called_once()
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once()
self.cd_dao.get_latest_current_dataset_by_model_uuid.assert_called_once()

assert res == ModelOut.from_model(model)
assert res == ModelOut.from_model(
model=model,
latest_reference_uuid=reference_dataset.uuid,
latest_current_uuid=current_dataset.uuid,
)

def test_get_model_by_uuid_not_found(self):
self.model_dao.get_by_uuid = MagicMock(return_value=None)
Expand Down Expand Up @@ -66,6 +88,12 @@ def test_get_all_models_paginated_ok(self):
sort=None,
)
self.model_dao.get_all_paginated = MagicMock(return_value=page)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=None
)

result = self.model_service.get_all_models_paginated(
params=Params(page=1, size=10), order=OrderType.ASC, sort=None
Expand All @@ -81,5 +109,27 @@ def test_get_all_models_paginated_ok(self):
assert result.items[1].name == 'model2'
assert result.items[2].name == 'model3'

def test_get_all_models_ok(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='model1')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='model2')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3')
sample_models = [model1, model2, model3]
self.model_dao.get_all = MagicMock(return_value=sample_models)
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.cd_dao.get_latest_current_dataset_by_model_uuid = MagicMock(
return_value=None
)

result = self.model_service.get_all_models()

self.model_dao.get_all.assert_called_once()

assert len(result) == 3
assert result[0].name == 'model1'
assert result[1].name == 'model2'
assert result[2].name == 'model3'


model_uuid = db_mock.MODEL_UUID
2 changes: 2 additions & 0 deletions sdk/radicalbit_platform_sdk/models/model_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,7 @@ class ModelDefinition(BaseModelDefinition):
uuid: uuid_lib.UUID = Field(default_factory=lambda: uuid_lib.uuid4())
created_at: str = Field(alias='createdAt')
updated_at: str = Field(alias='updatedAt')
latest_reference_uuid: Optional[uuid_lib.UUID] = Field(alias='latestReferenceUuid')
latest_current_uuid: Optional[uuid_lib.UUID] = Field(alias='latestCurrentUuid')

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
Loading