Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactoring of metrics part #38

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions api/app/models/metrics/data_quality_dto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -102,10 +102,10 @@ class ClassMetrics(BaseModel):
)


class BinaryClassDataQuality(BaseModel):
class ClassificationDataQuality(BaseModel):
n_observations: int
class_metrics: List[ClassMetrics]
feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]
feature_metrics: List[NumericalFeatureMetrics | CategoricalFeatureMetrics]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -115,19 +115,13 @@ class BinaryClassDataQuality(BaseModel):
)


class MultiClassDataQuality(BaseModel):
pass


class RegressionDataQuality(BaseModel):
pass


class DataQualityDTO(BaseModel):
job_status: JobStatus
data_quality: Optional[
Union[BinaryClassDataQuality, MultiClassDataQuality, RegressionDataQuality]
]
data_quality: Optional[ClassificationDataQuality | RegressionDataQuality]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -141,30 +135,32 @@ def from_dict(
model_type: ModelType,
job_status: JobStatus,
data_quality_data: Optional[Dict],
):
) -> 'DataQualityDTO':
"""Create a DataQualityDTO from a dictionary of data."""
if not data_quality_data:
return DataQualityDTO(
job_status=job_status,
data_quality=None,
)
match model_type:
case ModelType.BINARY:
binary_class_data_quality = BinaryClassDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=binary_class_data_quality,
)
case ModelType.MULTI_CLASS:
multi_class_data_quality = MultiClassDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=multi_class_data_quality,
)
case ModelType.REGRESSION:
regression_data_quality = RegressionDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=regression_data_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

data_quality = DataQualityDTO._create_data_quality(
model_type=model_type,
data_quality_data=data_quality_data,
)

return DataQualityDTO(
job_status=job_status,
data_quality=data_quality,
)

@staticmethod
def _create_data_quality(
model_type: ModelType,
data_quality_data: Dict,
) -> ClassificationDataQuality | RegressionDataQuality:
"""Create a specific data quality instance based on the model type."""
if model_type in {ModelType.BINARY, ModelType.MULTI_CLASS}:
return ClassificationDataQuality(**data_quality_data)
if model_type == ModelType.REGRESSION:
return RegressionDataQuality(**data_quality_data)
raise MetricsInternalError(f'Invalid model type {model_type}')
52 changes: 28 additions & 24 deletions api/app/models/metrics/drift_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -45,7 +45,7 @@ class RegressionDrift(BaseModel):

class DriftDTO(BaseModel):
job_status: JobStatus
drift: Optional[Union[BinaryClassDrift, MultiClassDrift, RegressionDrift]]
drift: Optional[BinaryClassDrift | MultiClassDrift | RegressionDrift]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -58,30 +58,34 @@ def from_dict(
model_type: ModelType,
job_status: JobStatus,
drift_data: Optional[Dict],
):
) -> 'DriftDTO':
"""Create a DriftDTO from a dictionary of data."""
if not drift_data:
return DriftDTO(
job_status=job_status,
drift=None,
)
match model_type:
case ModelType.BINARY:
binary_class_data_quality = BinaryClassDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=binary_class_data_quality,
)
case ModelType.MULTI_CLASS:
multi_class_data_quality = MultiClassDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=multi_class_data_quality,
)
case ModelType.REGRESSION:
regression_data_quality = RegressionDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=regression_data_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

drift = DriftDTO._create_drift(
model_type=model_type,
drift_data=drift_data,
)

return DriftDTO(
job_status=job_status,
drift=drift,
)

@staticmethod
def _create_drift(
model_type: ModelType,
drift_data: Dict,
) -> BinaryClassDrift | MultiClassDrift | RegressionDrift:
"""Create a specific drift instance based on the model type."""
if model_type == ModelType.BINARY:
return BinaryClassDrift(**drift_data)
if model_type == ModelType.MULTI_CLASS:
return MultiClassDrift(**drift_data)
if model_type == ModelType.REGRESSION:
return RegressionDrift(**drift_data)
raise MetricsInternalError(f'Invalid model type {model_type}')
102 changes: 57 additions & 45 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -80,12 +80,10 @@ class RegressionModelQuality(BaseModel):
class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
Union[
BinaryClassModelQuality,
CurrentBinaryClassModelQuality,
MultiClassModelQuality,
RegressionModelQuality,
]
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
| RegressionModelQuality
]

model_config = ConfigDict(
Expand All @@ -94,46 +92,60 @@ class ModelQualityDTO(BaseModel):

@staticmethod
def from_dict(
dataset_type: DatasetType = DatasetType.REFERENCE,
model_type: ModelType = ModelType.BINARY,
job_status: JobStatus = JobStatus.SUCCEEDED,
model_quality_data: Optional[Dict] = None,
):
dataset_type: DatasetType,
model_type: ModelType,
job_status: JobStatus,
model_quality_data: Optional[Dict],
) -> 'ModelQualityDTO':
"""Create a ModelQualityDTO from a dictionary of data."""
if not model_quality_data:
return ModelQualityDTO(
job_status=job_status,
model_quality=None,
)
match model_type:
case ModelType.BINARY:
match dataset_type:
case DatasetType.REFERENCE:
binary_class_model_quality = BinaryClassModelQuality(
**model_quality_data
)
return ModelQualityDTO(
job_status=job_status,
model_quality=binary_class_model_quality,
)
case DatasetType.CURRENT:
current_binary_class_model_quality = (
CurrentBinaryClassModelQuality(**model_quality_data)
)
return ModelQualityDTO(
job_status=job_status,
model_quality=current_binary_class_model_quality,
)
case ModelType.MULTI_CLASS:
multi_class_model_quality = MultiClassModelQuality(**model_quality_data)
return ModelQualityDTO(
job_status=job_status,
model_quality=multi_class_model_quality,
)
case ModelType.REGRESSION:
regression_model_quality = RegressionModelQuality(**model_quality_data)
return ModelQualityDTO(
job_status=job_status,
model_quality=regression_model_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

model_quality = ModelQualityDTO._create_model_quality(
model_type=model_type,
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)

return ModelQualityDTO(
job_status=job_status,
model_quality=model_quality,
)

@staticmethod
def _create_model_quality(
model_type: ModelType,
dataset_type: DatasetType,
model_quality_data: Dict,
) -> (
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
| RegressionModelQuality
):
"""Create a specific model quality instance based on model type and dataset type."""
if model_type == ModelType.BINARY:
return ModelQualityDTO._create_binary_model_quality(
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)
if model_type == ModelType.MULTI_CLASS:
return MultiClassModelQuality(**model_quality_data)
if model_type == ModelType.REGRESSION:
return RegressionModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid model type {model_type}')

@staticmethod
def _create_binary_model_quality(
dataset_type: DatasetType,
model_quality_data: Dict,
) -> BinaryClassModelQuality | CurrentBinaryClassModelQuality:
"""Create a binary model quality instance based on dataset type."""
if dataset_type == DatasetType.REFERENCE:
return BinaryClassModelQuality(**model_quality_data)
if dataset_type == DatasetType.CURRENT:
return CurrentBinaryClassModelQuality(**model_quality_data)
raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
18 changes: 12 additions & 6 deletions api/app/models/metrics/statistics_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@ class StatisticsDTO(BaseModel):
@staticmethod
def from_dict(
job_status: JobStatus, date: datetime, statistics_data: Optional[Dict]
):
if not statistics_data:
return StatisticsDTO(
job_status=job_status, statistics=None, date=date.isoformat()
)
statistics = Statistics(**statistics_data)
) -> 'StatisticsDTO':
"""Create a StatisticsDTO from a dictionary of data."""
statistics = StatisticsDTO._create_statistics(statistics_data)
return StatisticsDTO(
job_status=job_status,
statistics=statistics,
date=date.isoformat(),
)

@staticmethod
def _create_statistics(
statistics_data: Optional[Dict],
) -> Optional[Statistics]:
"""Create a Statistics instance from a dictionary of data."""
if not statistics_data:
return None
return Statistics(**statistics_data)
Loading