From 92342b9c4ad57b047174682d72442c4f12a4074e Mon Sep 17 00:00:00 2001 From: Mauro Cortellazzi Date: Wed, 24 Jul 2024 17:15:17 +0200 Subject: [PATCH] feat(api): add field type as column definition property (#133) --- api/app/models/inferred_schema_dto.py | 26 +++++++- api/app/models/model_dto.py | 27 +++++++- api/app/services/file_service.py | 7 ++- api/tests/commons/csv_file_mock.py | 57 ++++++++++++++--- api/tests/commons/db_mock.py | 49 +++++++++++---- api/tests/commons/modelin_factory.py | 83 ++++++++++++++++++++----- api/tests/services/file_service_test.py | 20 ++++-- 7 files changed, 223 insertions(+), 46 deletions(-) diff --git a/api/app/models/inferred_schema_dto.py b/api/app/models/inferred_schema_dto.py index 0f56cdeb..5f843c80 100644 --- a/api/app/models/inferred_schema_dto.py +++ b/api/app/models/inferred_schema_dto.py @@ -5,6 +5,7 @@ from pandas.core.arrays import boolean, floating, integer, string_ from pandas.core.dtypes import dtypes as pd_dtypes from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel from app.models.exceptions import UnsupportedSchemaException @@ -34,11 +35,34 @@ def cast(value) -> 'SupportedTypes': raise UnsupportedSchemaException(f'Unsupported type: {type(value)}') +class FieldType(str, Enum): + categorical = 'categorical' + numerical = 'numerical' + datetime = 'datetime' + + @staticmethod + def from_supported_type(value: SupportedTypes) -> 'FieldType': + match value: + case SupportedTypes.datetime: + return FieldType.datetime + case SupportedTypes.int: + return FieldType.numerical + case SupportedTypes.float: + return FieldType.numerical + case SupportedTypes.bool: + return FieldType.categorical + case SupportedTypes.string: + return FieldType.categorical + + class SchemaEntry(BaseModel): name: str type: SupportedTypes + field_type: FieldType - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict( + arbitrary_types_allowed=True, populate_by_name=True, 
alias_generator=to_camel + ) class InferredSchemaDTO(BaseModel): diff --git a/api/app/models/model_dto.py b/api/app/models/model_dto.py index fd4faa64..648f848b 100644 --- a/api/app/models/model_dto.py +++ b/api/app/models/model_dto.py @@ -9,7 +9,7 @@ from app.db.dao.current_dataset_dao import CurrentDataset from app.db.dao.model_dao import Model from app.db.dao.reference_dataset_dao import ReferenceDataset -from app.models.inferred_schema_dto import SupportedTypes +from app.models.inferred_schema_dto import FieldType, SupportedTypes from app.models.job_status import JobStatus from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float @@ -36,10 +36,35 @@ class Granularity(str, Enum): class ColumnDefinition(BaseModel, validate_assignment=True): name: str type: SupportedTypes + field_type: FieldType + + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) def to_dict(self): return self.model_dump() + @model_validator(mode='after') + def validate_field_type(self) -> Self: + match (self.type, self.field_type): + case (SupportedTypes.datetime, FieldType.datetime): + return self + case (SupportedTypes.string, FieldType.categorical): + return self + case (SupportedTypes.bool, FieldType.categorical): + return self + case (SupportedTypes.int, FieldType.categorical): + return self + case (SupportedTypes.float, FieldType.categorical): + return self + case (SupportedTypes.int, FieldType.numerical): + return self + case (SupportedTypes.float, FieldType.numerical): + return self + case _: + raise ValueError( + f'column {self.name} with type {self.type} can not have field type {self.field_type}' + ) + class OutputType(BaseModel, validate_assignment=True): prediction: ColumnDefinition diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py index 929b0bbb..254f3f2c 100644 --- a/api/app/services/file_service.py +++ b/api/app/services/file_service.py @@ -31,6 +31,7 @@ ModelNotFoundError, ) from 
app.models.inferred_schema_dto import ( + FieldType, InferredSchemaDTO, SchemaEntry, SupportedTypes, @@ -419,7 +420,11 @@ def schema_from_pandas(df: pd.DataFrame) -> InferredSchemaDTO: data = data.loc[:, ~data.columns.str.contains('Unnamed')] return InferredSchemaDTO( inferred_schema=[ - SchemaEntry(name=name.strip(), type=SupportedTypes.cast(type)) + SchemaEntry( + name=name.strip(), + type=SupportedTypes.cast(type), + field_type=FieldType.from_supported_type(SupportedTypes.cast(type)), + ) for name, type in data.convert_dtypes(infer_objects=True).dtypes.items() ] ) diff --git a/api/tests/commons/csv_file_mock.py b/api/tests/commons/csv_file_mock.py index 9fec928b..4bdce7e1 100644 --- a/api/tests/commons/csv_file_mock.py +++ b/api/tests/commons/csv_file_mock.py @@ -5,6 +5,7 @@ import pandas as pd from app.models.inferred_schema_dto import ( + FieldType, InferredSchemaDTO, SchemaEntry, SupportedTypes, @@ -28,16 +29,52 @@ def get_dataframe_with_sep(sep: str) -> pd.DataFrame: def correct_schema() -> InferredSchemaDTO: schema = [ - {'name': 'Name', 'type': SupportedTypes.string}, - {'name': 'Age', 'type': SupportedTypes.int}, - {'name': 'City', 'type': SupportedTypes.string}, - {'name': 'Salary', 'type': SupportedTypes.float}, - {'name': 'String', 'type': SupportedTypes.string}, - {'name': 'Float/Int', 'type': SupportedTypes.float}, - {'name': 'Boolean', 'type': SupportedTypes.bool}, - {'name': 'Datetime', 'type': SupportedTypes.datetime}, - {'name': 'Datetime2', 'type': SupportedTypes.datetime}, - {'name': 'Datetime3', 'type': SupportedTypes.datetime}, + { + 'name': 'Name', + 'type': SupportedTypes.string, + 'fieldType': FieldType.categorical, + }, + {'name': 'Age', 'type': SupportedTypes.int, 'fieldType': FieldType.numerical}, + { + 'name': 'City', + 'type': SupportedTypes.string, + 'fieldType': FieldType.categorical, + }, + { + 'name': 'Salary', + 'type': SupportedTypes.float, + 'fieldType': FieldType.numerical, + }, + { + 'name': 'String', + 'type': 
SupportedTypes.string, + 'fieldType': FieldType.categorical, + }, + { + 'name': 'Float/Int', + 'type': SupportedTypes.float, + 'fieldType': FieldType.numerical, + }, + { + 'name': 'Boolean', + 'type': SupportedTypes.bool, + 'fieldType': FieldType.categorical, + }, + { + 'name': 'Datetime', + 'type': SupportedTypes.datetime, + 'fieldType': FieldType.datetime, + }, + { + 'name': 'Datetime2', + 'type': SupportedTypes.datetime, + 'fieldType': FieldType.datetime, + }, + { + 'name': 'Datetime3', + 'type': SupportedTypes.datetime, + 'fieldType': FieldType.datetime, + }, ] schema = [SchemaEntry(**entry) for entry in schema] return InferredSchemaDTO(inferred_schema=schema) diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index 2a488446..3a034fde 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -11,6 +11,7 @@ from app.models.model_dto import ( ColumnDefinition, DataType, + FieldType, Granularity, ModelIn, ModelType, @@ -31,14 +32,24 @@ def get_sample_model( model_type: str = ModelType.BINARY.value, data_type: str = DataType.TEXT.value, granularity: str = Granularity.DAY.value, - features: List[Dict] = [{'name': 'feature1', 'type': 'string'}], + features: List[Dict] = [ + {'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'} + ], outputs: Dict = { - 'prediction': {'name': 'pred1', 'type': 'int'}, - 'prediction_proba': {'name': 'prob1', 'type': 'float'}, - 'output': [{'name': 'output1', 'type': 'string'}], + 'prediction': {'name': 'pred1', 'type': 'int', 'fieldType': 'numerical'}, + 'prediction_proba': { + 'name': 'prob1', + 'type': 'float', + 'fieldType': 'numerical', + }, + 'output': [{'name': 'output1', 'type': 'string', 'fieldType': 'categorical'}], + }, + target: Dict = {'name': 'target1', 'type': 'string', 'fieldType': 'categorical'}, + timestamp: Dict = { + 'name': 'timestamp', + 'type': 'datetime', + 'fieldType': 'datetime', }, - target: Dict = {'name': 'target1', 'type': 'string'}, - timestamp: Dict = 
{'name': 'timestamp', 'type': 'datetime'}, frameworks: Optional[str] = None, algorithm: Optional[str] = None, ) -> Model: @@ -68,18 +79,32 @@ def get_sample_model_in( data_type: str = DataType.TEXT.value, granularity: str = Granularity.DAY.value, features: List[ColumnDefinition] = [ - ColumnDefinition(name='feature1', type=SupportedTypes.string) + ColumnDefinition( + name='feature1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) ], outputs: OutputType = OutputType( - prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int), - prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float), - output=[ColumnDefinition(name='output1', type=SupportedTypes.string)], + prediction=ColumnDefinition( + name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical + ), + prediction_proba=ColumnDefinition( + name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical + ), + output=[ + ColumnDefinition( + name='output1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], ), target: ColumnDefinition = ColumnDefinition( - name='target1', type=SupportedTypes.int + name='target1', type=SupportedTypes.int, field_type=FieldType.numerical ), timestamp: ColumnDefinition = ColumnDefinition( - name='timestamp', type=SupportedTypes.datetime + name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime ), frameworks: Optional[str] = None, algorithm: Optional[str] = None, diff --git a/api/tests/commons/modelin_factory.py b/api/tests/commons/modelin_factory.py index 2417b231..f2fc7091 100644 --- a/api/tests/commons/modelin_factory.py +++ b/api/tests/commons/modelin_factory.py @@ -4,6 +4,7 @@ from app.models.model_dto import ( ColumnDefinition, DataType, + FieldType, Granularity, ModelType, OutputType, @@ -11,46 +12,98 @@ def get_model_sample_wrong(fail_fields: List[str], model_type: ModelType): - prediction = ColumnDefinition(name='pred1', type=SupportedTypes.int) - prediction_proba = 
ColumnDefinition(name='prob1', type=SupportedTypes.float) - target = ColumnDefinition(name='target1', type=SupportedTypes.int) - timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.datetime) + prediction = ColumnDefinition( + name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical + ) + prediction_proba = ColumnDefinition( + name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical + ) + target = ColumnDefinition( + name='target1', type=SupportedTypes.int, field_type=FieldType.numerical + ) + timestamp = ColumnDefinition( + name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime + ) if 'outputs.prediction' in fail_fields: if model_type == ModelType.BINARY: - prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string) + prediction = ColumnDefinition( + name='pred1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) elif model_type == ModelType.MULTI_CLASS: - prediction = ColumnDefinition(name='pred1', type=SupportedTypes.datetime) + prediction = ColumnDefinition( + name='pred1', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ) elif model_type == ModelType.REGRESSION: - prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string) + prediction = ColumnDefinition( + name='pred1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) if 'outputs.prediction_proba' in fail_fields: if model_type in (ModelType.BINARY, ModelType.MULTI_CLASS): - prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.int) + prediction_proba = ColumnDefinition( + name='prob1', type=SupportedTypes.int, field_type=FieldType.numerical + ) elif model_type == ModelType.REGRESSION: - prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float) + prediction_proba = ColumnDefinition( + name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical + ) if 'target' in fail_fields: if model_type == ModelType.BINARY: - target 
= ColumnDefinition(name='target1', type=SupportedTypes.string) + target = ColumnDefinition( + name='target1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) elif model_type == ModelType.MULTI_CLASS: - target = ColumnDefinition(name='target1', type=SupportedTypes.datetime) + target = ColumnDefinition( + name='target1', + type=SupportedTypes.datetime, + field_type=FieldType.datetime, + ) elif model_type == ModelType.REGRESSION: - target = ColumnDefinition(name='target1', type=SupportedTypes.string) + target = ColumnDefinition( + name='target1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) if 'timestamp' in fail_fields: - timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.string) + timestamp = ColumnDefinition( + name='timestamp', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) return { 'name': 'model_name', 'model_type': model_type, 'data_type': DataType.TEXT, 'granularity': Granularity.DAY, - 'features': [ColumnDefinition(name='feature1', type=SupportedTypes.string)], + 'features': [ + ColumnDefinition( + name='feature1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], 'outputs': OutputType( prediction=prediction, prediction_proba=prediction_proba, - output=[ColumnDefinition(name='output1', type=SupportedTypes.string)], + output=[ + ColumnDefinition( + name='output1', + type=SupportedTypes.string, + field_type=FieldType.categorical, + ) + ], ), 'target': target, 'timestamp': timestamp, diff --git a/api/tests/services/file_service_test.py b/api/tests/services/file_service_test.py index 92bf4fbc..96231650 100644 --- a/api/tests/services/file_service_test.py +++ b/api/tests/services/file_service_test.py @@ -188,14 +188,22 @@ def test_bind_reference_file_already_exists(self): def test_upload_current_file_ok(self): file = csv.get_current_sample_csv_file() model = db_mock.get_sample_model( - features=[{'name': 'num1', 'type': 'int'}], + features=[{'name': 
'num1', 'type': 'int', 'fieldType': 'numerical'}], outputs={ - 'prediction': {'name': 'prediction', 'type': 'int'}, - 'prediction_proba': {'name': 'prediction_proba', 'type': 'int'}, - 'output': [{'name': 'num2', 'type': 'int'}], + 'prediction': { + 'name': 'prediction', + 'type': 'int', + 'fieldType': 'numerical', + }, + 'prediction_proba': { + 'name': 'prediction_proba', + 'type': 'int', + 'fieldType': 'numerical', + }, + 'output': [{'name': 'num2', 'type': 'int', 'fieldType': 'numerical'}], }, - target={'name': 'target', 'type': 'int'}, - timestamp={'name': 'datetime', 'type': 'datetime'}, + target={'name': 'target', 'type': 'int', 'fieldType': 'numerical'}, + timestamp={'name': 'datetime', 'type': 'datetime', 'fieldType': 'datetime'}, ) object_name = f'{str(model.uuid)}/current/{file.filename}' path = f's3://bucket/{object_name}'