Skip to content

Commit

Permalink
feat: refactoring of metrics part (#38)
Browse files Browse the repository at this point in the history
* feat: refactoring of metrics part

* fix: run ruff check
  • Loading branch information
dtria91 authored Jun 27, 2024
1 parent 0780b0e commit 9856b55
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 333 deletions.
60 changes: 28 additions & 32 deletions app/models/metrics/data_quality_dto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -102,10 +102,10 @@ class ClassMetrics(BaseModel):
)


class BinaryClassDataQuality(BaseModel):
class ClassificationDataQuality(BaseModel):
n_observations: int
class_metrics: List[ClassMetrics]
feature_metrics: List[Union[NumericalFeatureMetrics, CategoricalFeatureMetrics]]
feature_metrics: List[NumericalFeatureMetrics | CategoricalFeatureMetrics]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -115,19 +115,13 @@ class BinaryClassDataQuality(BaseModel):
)


class MultiClassDataQuality(BaseModel):
pass


class RegressionDataQuality(BaseModel):
pass


class DataQualityDTO(BaseModel):
job_status: JobStatus
data_quality: Optional[
Union[BinaryClassDataQuality, MultiClassDataQuality, RegressionDataQuality]
]
data_quality: Optional[ClassificationDataQuality | RegressionDataQuality]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -141,30 +135,32 @@ def from_dict(
model_type: ModelType,
job_status: JobStatus,
data_quality_data: Optional[Dict],
):
) -> 'DataQualityDTO':
"""Create a DataQualityDTO from a dictionary of data."""
if not data_quality_data:
return DataQualityDTO(
job_status=job_status,
data_quality=None,
)
match model_type:
case ModelType.BINARY:
binary_class_data_quality = BinaryClassDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=binary_class_data_quality,
)
case ModelType.MULTI_CLASS:
multi_class_data_quality = MultiClassDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=multi_class_data_quality,
)
case ModelType.REGRESSION:
regression_data_quality = RegressionDataQuality(**data_quality_data)
return DataQualityDTO(
job_status=job_status,
data_quality=regression_data_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

data_quality = DataQualityDTO._create_data_quality(
model_type=model_type,
data_quality_data=data_quality_data,
)

return DataQualityDTO(
job_status=job_status,
data_quality=data_quality,
)

@staticmethod
def _create_data_quality(
    model_type: ModelType,
    data_quality_data: Dict,
) -> ClassificationDataQuality | RegressionDataQuality:
    """Build the data quality payload model that matches ``model_type``.

    Binary and multi-class model types both map to
    ``ClassificationDataQuality``; the regression model type maps to
    ``RegressionDataQuality``. ``data_quality_data`` is unpacked as keyword
    arguments into the selected class.

    Raises:
        MetricsInternalError: if ``model_type`` is none of the handled types.
    """
    classification_types = {ModelType.BINARY, ModelType.MULTI_CLASS}
    if model_type in classification_types:
        payload_cls = ClassificationDataQuality
    elif model_type == ModelType.REGRESSION:
        payload_cls = RegressionDataQuality
    else:
        raise MetricsInternalError(f'Invalid model type {model_type}')
    return payload_cls(**data_quality_data)
52 changes: 28 additions & 24 deletions app/models/metrics/drift_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -45,7 +45,7 @@ class RegressionDrift(BaseModel):

class DriftDTO(BaseModel):
job_status: JobStatus
drift: Optional[Union[BinaryClassDrift, MultiClassDrift, RegressionDrift]]
drift: Optional[BinaryClassDrift | MultiClassDrift | RegressionDrift]

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand All @@ -58,30 +58,34 @@ def from_dict(
model_type: ModelType,
job_status: JobStatus,
drift_data: Optional[Dict],
):
) -> 'DriftDTO':
"""Create a DriftDTO from a dictionary of data."""
if not drift_data:
return DriftDTO(
job_status=job_status,
drift=None,
)
match model_type:
case ModelType.BINARY:
binary_class_data_quality = BinaryClassDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=binary_class_data_quality,
)
case ModelType.MULTI_CLASS:
multi_class_data_quality = MultiClassDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=multi_class_data_quality,
)
case ModelType.REGRESSION:
regression_data_quality = RegressionDrift(**drift_data)
return DriftDTO(
job_status=job_status,
drift=regression_data_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

drift = DriftDTO._create_drift(
model_type=model_type,
drift_data=drift_data,
)

return DriftDTO(
job_status=job_status,
drift=drift,
)

@staticmethod
def _create_drift(
    model_type: ModelType,
    drift_data: Dict,
) -> BinaryClassDrift | MultiClassDrift | RegressionDrift:
    """Build the drift payload model that matches ``model_type``.

    Each handled model type (binary, multi-class, regression) has its own
    drift class; ``drift_data`` is unpacked as keyword arguments into it.

    Raises:
        MetricsInternalError: if ``model_type`` is none of the handled types.
    """
    if model_type == ModelType.BINARY:
        drift_cls = BinaryClassDrift
    elif model_type == ModelType.MULTI_CLASS:
        drift_cls = MultiClassDrift
    elif model_type == ModelType.REGRESSION:
        drift_cls = RegressionDrift
    else:
        raise MetricsInternalError(f'Invalid model type {model_type}')
    return drift_cls(**drift_data)
102 changes: 57 additions & 45 deletions app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -80,12 +80,10 @@ class RegressionModelQuality(BaseModel):
class ModelQualityDTO(BaseModel):
job_status: JobStatus
model_quality: Optional[
Union[
BinaryClassModelQuality,
CurrentBinaryClassModelQuality,
MultiClassModelQuality,
RegressionModelQuality,
]
BinaryClassModelQuality
| CurrentBinaryClassModelQuality
| MultiClassModelQuality
| RegressionModelQuality
]

model_config = ConfigDict(
Expand All @@ -94,46 +92,60 @@ class ModelQualityDTO(BaseModel):

@staticmethod
def from_dict(
dataset_type: DatasetType = DatasetType.REFERENCE,
model_type: ModelType = ModelType.BINARY,
job_status: JobStatus = JobStatus.SUCCEEDED,
model_quality_data: Optional[Dict] = None,
):
dataset_type: DatasetType,
model_type: ModelType,
job_status: JobStatus,
model_quality_data: Optional[Dict],
) -> 'ModelQualityDTO':
"""Create a ModelQualityDTO from a dictionary of data."""
if not model_quality_data:
return ModelQualityDTO(
job_status=job_status,
model_quality=None,
)
match model_type:
case ModelType.BINARY:
match dataset_type:
case DatasetType.REFERENCE:
binary_class_model_quality = BinaryClassModelQuality(
**model_quality_data
)
return ModelQualityDTO(
job_status=job_status,
model_quality=binary_class_model_quality,
)
case DatasetType.CURRENT:
current_binary_class_model_quality = (
CurrentBinaryClassModelQuality(**model_quality_data)
)
return ModelQualityDTO(
job_status=job_status,
model_quality=current_binary_class_model_quality,
)
case ModelType.MULTI_CLASS:
multi_class_model_quality = MultiClassModelQuality(**model_quality_data)
return ModelQualityDTO(
job_status=job_status,
model_quality=multi_class_model_quality,
)
case ModelType.REGRESSION:
regression_model_quality = RegressionModelQuality(**model_quality_data)
return ModelQualityDTO(
job_status=job_status,
model_quality=regression_model_quality,
)
case _:
raise MetricsInternalError(f'Invalid model type {model_type}')

model_quality = ModelQualityDTO._create_model_quality(
model_type=model_type,
dataset_type=dataset_type,
model_quality_data=model_quality_data,
)

return ModelQualityDTO(
job_status=job_status,
model_quality=model_quality,
)

@staticmethod
def _create_model_quality(
    model_type: ModelType,
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> (
    BinaryClassModelQuality
    | CurrentBinaryClassModelQuality
    | MultiClassModelQuality
    | RegressionModelQuality
):
    """Build the model quality payload for a model type and dataset type.

    Binary models are delegated to ``_create_binary_model_quality``, which
    also takes ``dataset_type`` into account. Multi-class and regression
    models are built directly from ``model_quality_data`` and do not use
    ``dataset_type``.

    Raises:
        MetricsInternalError: if ``model_type`` is none of the handled types.
    """
    if model_type == ModelType.BINARY:
        # Only the binary case distinguishes between dataset types.
        return ModelQualityDTO._create_binary_model_quality(
            dataset_type=dataset_type,
            model_quality_data=model_quality_data,
        )

    if model_type == ModelType.MULTI_CLASS:
        quality_cls = MultiClassModelQuality
    elif model_type == ModelType.REGRESSION:
        quality_cls = RegressionModelQuality
    else:
        raise MetricsInternalError(f'Invalid model type {model_type}')
    return quality_cls(**model_quality_data)

@staticmethod
def _create_binary_model_quality(
    dataset_type: DatasetType,
    model_quality_data: Dict,
) -> BinaryClassModelQuality | CurrentBinaryClassModelQuality:
    """Build the binary model quality payload for the given dataset type.

    A reference dataset yields ``BinaryClassModelQuality``; a current
    dataset yields ``CurrentBinaryClassModelQuality``. In both cases
    ``model_quality_data`` is unpacked as keyword arguments.

    Raises:
        MetricsInternalError: if ``dataset_type`` is neither reference nor
            current.
    """
    if dataset_type == DatasetType.REFERENCE:
        quality_cls = BinaryClassModelQuality
    elif dataset_type == DatasetType.CURRENT:
        quality_cls = CurrentBinaryClassModelQuality
    else:
        raise MetricsInternalError(f'Invalid dataset type {dataset_type}')
    return quality_cls(**model_quality_data)
18 changes: 12 additions & 6 deletions app/models/metrics/statistics_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@ class StatisticsDTO(BaseModel):
@staticmethod
def from_dict(
job_status: JobStatus, date: datetime, statistics_data: Optional[Dict]
):
if not statistics_data:
return StatisticsDTO(
job_status=job_status, statistics=None, date=date.isoformat()
)
statistics = Statistics(**statistics_data)
) -> 'StatisticsDTO':
"""Create a StatisticsDTO from a dictionary of data."""
statistics = StatisticsDTO._create_statistics(statistics_data)
return StatisticsDTO(
job_status=job_status,
statistics=statistics,
date=date.isoformat(),
)

@staticmethod
def _create_statistics(
    statistics_data: Optional[Dict],
) -> Optional[Statistics]:
    """Build a ``Statistics`` instance, or ``None`` when there is no data.

    Any falsy ``statistics_data`` (``None`` or an empty dict) yields
    ``None``; otherwise the dict is unpacked as keyword arguments.
    """
    return Statistics(**statistics_data) if statistics_data else None
Loading

0 comments on commit 9856b55

Please sign in to comment.