Skip to content

Commit

Permalink
chore: add model_type validations
Browse files Browse the repository at this point in the history
  • Loading branch information
bigmoby committed Jun 29, 2024
1 parent 64218fa commit 4de7c81
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 17 deletions.
83 changes: 76 additions & 7 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import datetime
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Self
from uuid import UUID

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic.alias_generators import to_camel

from app.db.dao.model_dao import Model
from app.models.inferred_schema_dto import SupportedTypes
from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float


class ModelType(str, Enum):
Expand All @@ -28,26 +30,25 @@ class Granularity(str, Enum):
MONTH = 'MONTH'


class ColumnDefinition(BaseModel, validate_assignment=True):
    """A single named column together with its declared data type.

    ``validate_assignment=True`` asks pydantic to re-run validation when an
    attribute is reassigned after construction, not only at creation time.
    """

    # Column name as it appears in the dataset.
    name: str
    # Restricted to the SupportedTypes enumeration rather than a free string.
    type: SupportedTypes

    def to_dict(self):
        """Return the column definition as produced by ``model_dump()``."""
        return self.model_dump()


class OutputType(BaseModel, validate_assignment=True):
    """Description of the columns a model produces.

    ``validate_assignment=True`` asks pydantic to re-run validation when an
    attribute is reassigned after construction, not only at creation time.
    """

    # Column holding the predicted value.
    prediction: ColumnDefinition
    # Optional column holding the prediction probability; absent by default.
    prediction_proba: Optional[ColumnDefinition] = None
    # Output columns of the model.
    output: List[ColumnDefinition]

    # Fields can be populated by their Python name or by the camelCase alias
    # generated by ``to_camel``.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    def to_dict(self):
        """Return the output definition as produced by ``model_dump()``."""
        return self.model_dump()


class ModelIn(BaseModel):
class ModelIn(BaseModel, validate_assignment=True):
name: str
description: Optional[str] = None
model_type: ModelType
Expand All @@ -64,6 +65,74 @@ class ModelIn(BaseModel):
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)

@model_validator(mode='after')
def validate_target(self) -> Self:
    """Ensure the target column type is acceptable for the model type.

    BINARY and REGRESSION models require a target type accepted by
    ``is_number``; MULTI_CLASS models require one accepted by
    ``is_number_or_string``.

    Returns:
        The instance itself when the target type is accepted.

    Raises:
        ValueError: if the target type is rejected for the model type, or
            if the model type is none of the three handled here.
    """
    kind: ModelType = self.model_type

    if kind == ModelType.BINARY:
        if is_number(self.target.type):
            return self
        raise ValueError(
            f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]'
        )

    if kind == ModelType.MULTI_CLASS:
        if is_number_or_string(self.target.type):
            return self
        raise ValueError(
            f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]'
        )

    if kind == ModelType.REGRESSION:
        if is_number(self.target.type):
            return self
        raise ValueError(
            f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]'
        )

    # Anything else is a model type this validator does not know about.
    raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def validate_outputs(self) -> Self:
checked_model_type: ModelType = self.model_type
match checked_model_type:
case ModelType.BINARY:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.MULTI_CLASS:
if not is_number_or_string(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.REGRESSION:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]'
)
if not is_none(self.outputs.prediction_proba.type):
raise ValueError(
f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]'
)
return self
case _:
raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def timestamp_must_be_datetime(self) -> Self:
    """Ensure the timestamp column is declared as a datetime.

    Returns:
        The instance itself when the timestamp type is
        ``SupportedTypes.datetime``.

    Raises:
        ValueError: if the timestamp column has any other type.
    """
    if self.timestamp.type == SupportedTypes.datetime:
        return self
    raise ValueError('timestamp must be a datetime')

def to_model(self) -> Model:
now = datetime.datetime.now(tz=datetime.UTC)
return Model(
Expand Down
34 changes: 24 additions & 10 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@
from app.db.tables.reference_dataset_metrics_table import ReferenceDatasetMetrics
from app.db.tables.reference_dataset_table import ReferenceDataset
from app.models.job_status import JobStatus
from app.models.model_dto import DataType, Granularity, ModelIn, ModelType
from app.models.model_dto import (
ColumnDefinition,
DataType,
Granularity,
ModelIn,
ModelType,
OutputType,
SupportedTypes,
)

MODEL_UUID = uuid.uuid4()
REFERENCE_UUID = uuid.uuid4()
Expand All @@ -26,7 +34,7 @@ def get_sample_model(
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'double'},
'prediction_proba': {'name': 'prob1', 'type': 'float'},
'output': [{'name': 'output1', 'type': 'string'}],
},
target: Dict = {'name': 'target1', 'type': 'string'},
Expand Down Expand Up @@ -59,14 +67,20 @@ def get_sample_model_in(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'double'},
'output': [{'name': 'output1', 'type': 'string'}],
},
target: Dict = {'name': 'target1', 'type': 'string'},
timestamp: Dict = {'name': 'timestamp', 'type': 'datetime'},
features: List[ColumnDefinition] = [
ColumnDefinition(name='feature1', type=SupportedTypes.string)
],
outputs: OutputType = OutputType(
prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int),
prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float),
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
),
target: ColumnDefinition = ColumnDefinition(
name='target1', type=SupportedTypes.string
),
timestamp: ColumnDefinition = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime
),
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
):
Expand Down

0 comments on commit 4de7c81

Please sign in to comment.