Skip to content

Commit

Permalink
feat(api): add field type as column definition property (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jul 24, 2024
1 parent 7eeedae commit 92342b9
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 46 deletions.
26 changes: 25 additions & 1 deletion api/app/models/inferred_schema_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pandas.core.arrays import boolean, floating, integer, string_
from pandas.core.dtypes import dtypes as pd_dtypes
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from app.models.exceptions import UnsupportedSchemaException

Expand Down Expand Up @@ -34,11 +35,34 @@ def cast(value) -> 'SupportedTypes':
raise UnsupportedSchemaException(f'Unsupported type: {type(value)}')


class FieldType(str, Enum):
categorical = 'categorical'
numerical = 'numerical'
datetime = 'datetime'

@staticmethod
def from_supported_type(value: SupportedTypes) -> 'FieldType':
match value:
case SupportedTypes.datetime:
return FieldType.datetime
case SupportedTypes.int:
return FieldType.numerical
case SupportedTypes.float:
return FieldType.numerical
case SupportedTypes.bool:
return FieldType.categorical
case SupportedTypes.string:
return FieldType.categorical


class SchemaEntry(BaseModel):
name: str
type: SupportedTypes
field_type: FieldType

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(
arbitrary_types_allowed=True, populate_by_name=True, alias_generator=to_camel
)


class InferredSchemaDTO(BaseModel):
Expand Down
27 changes: 26 additions & 1 deletion api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.db.dao.current_dataset_dao import CurrentDataset
from app.db.dao.model_dao import Model
from app.db.dao.reference_dataset_dao import ReferenceDataset
from app.models.inferred_schema_dto import SupportedTypes
from app.models.inferred_schema_dto import FieldType, SupportedTypes
from app.models.job_status import JobStatus
from app.models.utils import is_none, is_number, is_number_or_string, is_optional_float

Expand All @@ -36,10 +36,35 @@ class Granularity(str, Enum):
class ColumnDefinition(BaseModel, validate_assignment=True):
name: str
type: SupportedTypes
field_type: FieldType

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

def to_dict(self):
return self.model_dump()

@model_validator(mode='after')
def validate_field_type(self) -> Self:
match (self.type, self.field_type):
case (SupportedTypes.datetime, FieldType.datetime):
return self
case (SupportedTypes.string, FieldType.categorical):
return self
case (SupportedTypes.bool, FieldType.categorical):
return self
case (SupportedTypes.int, FieldType.categorical):
return self
case (SupportedTypes.float, FieldType.categorical):
return self
case (SupportedTypes.int, FieldType.numerical):
return self
case (SupportedTypes.float, FieldType.numerical):
return self
case _:
raise ValueError(
f'column {self.name} with type {self.type} can not have filed type {self.field_type}'
)


class OutputType(BaseModel, validate_assignment=True):
prediction: ColumnDefinition
Expand Down
7 changes: 6 additions & 1 deletion api/app/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelNotFoundError,
)
from app.models.inferred_schema_dto import (
FieldType,
InferredSchemaDTO,
SchemaEntry,
SupportedTypes,
Expand Down Expand Up @@ -419,7 +420,11 @@ def schema_from_pandas(df: pd.DataFrame) -> InferredSchemaDTO:
data = data.loc[:, ~data.columns.str.contains('Unnamed')]
return InferredSchemaDTO(
inferred_schema=[
SchemaEntry(name=name.strip(), type=SupportedTypes.cast(type))
SchemaEntry(
name=name.strip(),
type=SupportedTypes.cast(type),
field_type=FieldType.from_supported_type(SupportedTypes.cast(type)),
)
for name, type in data.convert_dtypes(infer_objects=True).dtypes.items()
]
)
Expand Down
57 changes: 47 additions & 10 deletions api/tests/commons/csv_file_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from app.models.inferred_schema_dto import (
FieldType,
InferredSchemaDTO,
SchemaEntry,
SupportedTypes,
Expand All @@ -28,16 +29,52 @@ def get_dataframe_with_sep(sep: str) -> pd.DataFrame:

def correct_schema() -> InferredSchemaDTO:
schema = [
{'name': 'Name', 'type': SupportedTypes.string},
{'name': 'Age', 'type': SupportedTypes.int},
{'name': 'City', 'type': SupportedTypes.string},
{'name': 'Salary', 'type': SupportedTypes.float},
{'name': 'String', 'type': SupportedTypes.string},
{'name': 'Float/Int', 'type': SupportedTypes.float},
{'name': 'Boolean', 'type': SupportedTypes.bool},
{'name': 'Datetime', 'type': SupportedTypes.datetime},
{'name': 'Datetime2', 'type': SupportedTypes.datetime},
{'name': 'Datetime3', 'type': SupportedTypes.datetime},
{
'name': 'Name',
'type': SupportedTypes.string,
'fieldType': FieldType.categorical,
},
{'name': 'Age', 'type': SupportedTypes.int, 'fieldType': FieldType.numerical},
{
'name': 'City',
'type': SupportedTypes.string,
'fieldType': FieldType.categorical,
},
{
'name': 'Salary',
'type': SupportedTypes.float,
'fieldType': FieldType.numerical,
},
{
'name': 'String',
'type': SupportedTypes.string,
'fieldType': FieldType.categorical,
},
{
'name': 'Float/Int',
'type': SupportedTypes.float,
'fieldType': FieldType.numerical,
},
{
'name': 'Boolean',
'type': SupportedTypes.bool,
'fieldType': FieldType.categorical,
},
{
'name': 'Datetime',
'type': SupportedTypes.datetime,
'fieldType': FieldType.datetime,
},
{
'name': 'Datetime2',
'type': SupportedTypes.datetime,
'fieldType': FieldType.datetime,
},
{
'name': 'Datetime3',
'type': SupportedTypes.datetime,
'fieldType': FieldType.datetime,
},
]
schema = [SchemaEntry(**entry) for entry in schema]
return InferredSchemaDTO(inferred_schema=schema)
Expand Down
49 changes: 37 additions & 12 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from app.models.model_dto import (
ColumnDefinition,
DataType,
FieldType,
Granularity,
ModelIn,
ModelType,
Expand All @@ -31,14 +32,24 @@ def get_sample_model(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[Dict] = [{'name': 'feature1', 'type': 'string'}],
features: List[Dict] = [
{'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'}
],
outputs: Dict = {
'prediction': {'name': 'pred1', 'type': 'int'},
'prediction_proba': {'name': 'prob1', 'type': 'float'},
'output': [{'name': 'output1', 'type': 'string'}],
'prediction': {'name': 'pred1', 'type': 'int', 'fieldType': 'numerical'},
'prediction_proba': {
'name': 'prob1',
'type': 'float',
'fieldType': 'numerical',
},
'output': [{'name': 'output1', 'type': 'string', 'fieldType': 'categorical'}],
},
target: Dict = {'name': 'target1', 'type': 'string', 'fieldType': 'categorical'},
timestamp: Dict = {
'name': 'timestamp',
'type': 'datetime',
'fieldType': 'datetime',
},
target: Dict = {'name': 'target1', 'type': 'string'},
timestamp: Dict = {'name': 'timestamp', 'type': 'datetime'},
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
) -> Model:
Expand Down Expand Up @@ -68,18 +79,32 @@ def get_sample_model_in(
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[ColumnDefinition] = [
ColumnDefinition(name='feature1', type=SupportedTypes.string)
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
outputs: OutputType = OutputType(
prediction=ColumnDefinition(name='pred1', type=SupportedTypes.int),
prediction_proba=ColumnDefinition(name='prob1', type=SupportedTypes.float),
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
prediction=ColumnDefinition(
name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical
),
prediction_proba=ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
),
output=[
ColumnDefinition(
name='output1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
),
target: ColumnDefinition = ColumnDefinition(
name='target1', type=SupportedTypes.int
name='target1', type=SupportedTypes.int, field_type=FieldType.numerical
),
timestamp: ColumnDefinition = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
),
frameworks: Optional[str] = None,
algorithm: Optional[str] = None,
Expand Down
83 changes: 68 additions & 15 deletions api/tests/commons/modelin_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,106 @@
from app.models.model_dto import (
ColumnDefinition,
DataType,
FieldType,
Granularity,
ModelType,
OutputType,
)


def get_model_sample_wrong(fail_fields: List[str], model_type: ModelType):
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.int)
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float)
target = ColumnDefinition(name='target1', type=SupportedTypes.int)
timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.datetime)
prediction = ColumnDefinition(
name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical
)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
)
target = ColumnDefinition(
name='target1', type=SupportedTypes.int, field_type=FieldType.numerical
)
timestamp = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
)

if 'outputs.prediction' in fail_fields:
if model_type == ModelType.BINARY:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
elif model_type == ModelType.MULTI_CLASS:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.datetime)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.datetime,
field_type=FieldType.datetime,
)
elif model_type == ModelType.REGRESSION:
prediction = ColumnDefinition(name='pred1', type=SupportedTypes.string)
prediction = ColumnDefinition(
name='pred1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

if 'outputs.prediction_proba' in fail_fields:
if model_type in (ModelType.BINARY, ModelType.MULTI_CLASS):
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.int)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.int, field_type=FieldType.numerical
)
elif model_type == ModelType.REGRESSION:
prediction_proba = ColumnDefinition(name='prob1', type=SupportedTypes.float)
prediction_proba = ColumnDefinition(
name='prob1', type=SupportedTypes.float, field_type=FieldType.numerical
)

if 'target' in fail_fields:
if model_type == ModelType.BINARY:
target = ColumnDefinition(name='target1', type=SupportedTypes.string)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
elif model_type == ModelType.MULTI_CLASS:
target = ColumnDefinition(name='target1', type=SupportedTypes.datetime)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.datetime,
field_type=FieldType.datetime,
)
elif model_type == ModelType.REGRESSION:
target = ColumnDefinition(name='target1', type=SupportedTypes.string)
target = ColumnDefinition(
name='target1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

if 'timestamp' in fail_fields:
timestamp = ColumnDefinition(name='timestamp', type=SupportedTypes.string)
timestamp = ColumnDefinition(
name='timestamp',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)

return {
'name': 'model_name',
'model_type': model_type,
'data_type': DataType.TEXT,
'granularity': Granularity.DAY,
'features': [ColumnDefinition(name='feature1', type=SupportedTypes.string)],
'features': [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
'outputs': OutputType(
prediction=prediction,
prediction_proba=prediction_proba,
output=[ColumnDefinition(name='output1', type=SupportedTypes.string)],
output=[
ColumnDefinition(
name='output1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
),
'target': target,
'timestamp': timestamp,
Expand Down
Loading

0 comments on commit 92342b9

Please sign in to comment.