Skip to content

Commit

Permalink
Improve algorithm validation messages. (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
KFilippopolitis authored Jul 24, 2024
1 parent fcb3188 commit 878262f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 40 deletions.
52 changes: 21 additions & 31 deletions exareme2/controller/services/api/algorithm_request_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def validate_algorithm_request(
(
training_datasets,
validation_datasets,
) = worker_landscape_aggregator.get_train_and_validation_datasets(
) = worker_landscape_aggregator.get_training_and_validation_datasets(
algorithm_request_dto.inputdata.data_model
)
data_model_cdes = worker_landscape_aggregator.get_cdes(
Expand Down Expand Up @@ -92,14 +92,11 @@ def _validate_algorithm_request_body(
smpc_enabled: bool,
smpc_optional: bool,
):
_ensure_validation_criteria(
algorithm_request_dto.inputdata.validation_datasets,
algorithm_specs.inputdata.validation,
)
_validate_inputdata(
inputdata=algorithm_request_dto.inputdata,
inputdata_specs=algorithm_specs.inputdata,
training_datasets=training_datasets,
algorithm_specification_validation_flag=algorithm_specs.inputdata.validation,
validation_datasets=validation_datasets,
data_model_cdes=data_model_cdes,
)
Expand All @@ -125,42 +122,23 @@ def _validate_algorithm_request_body(
)


def _ensure_validation_criteria(validation_datasets: List[str], validation: bool):
"""
Validates the input based on the provided validation flag and datasets.
Parameters:
validation_datasets (List[str]): List of validation datasets.
validation (bool): Flag indicating if validation is required.
Raises:
BadUserInput: If the input conditions are not met.
"""
if not validation and validation_datasets:
raise BadUserInput(
"Validation is false, but validation datasets were provided."
)
elif validation and not validation_datasets:
raise BadUserInput(
"Validation is true, but no validation datasets were provided."
)


def _validate_inputdata(
inputdata: AlgorithmInputDataDTO,
inputdata_specs: InputDataSpecifications,
training_datasets: List[str],
algorithm_specification_validation_flag: Optional[bool],
validation_datasets: List[str],
data_model_cdes: Dict[str, CommonDataElement],
):
_validate_inputdata_training_datasets(
requested_data_model=inputdata.data_model,
requested_datasets=inputdata.datasets,
requested_training_datasets=inputdata.datasets,
training_datasets=training_datasets,
)
_validate_inputdata_validation_datasets(
requested_data_model=inputdata.data_model,
requested_validation_datasets=inputdata.validation_datasets,
algorithm_specification_validation_flag=algorithm_specification_validation_flag,
validation_datasets=validation_datasets,
)
_validate_inputdata_filter(inputdata.data_model, inputdata.filters, data_model_cdes)
Expand All @@ -169,14 +147,16 @@ def _validate_inputdata(

def _validate_inputdata_training_datasets(
requested_data_model: str,
requested_datasets: List[str],
requested_training_datasets: List[str],
training_datasets: List[str],
):
"""
Validates that the dataset values exist and that the datasets.
Validates that the dataset values exist.
"""
non_existing_datasets = [
dataset for dataset in requested_datasets if dataset not in training_datasets
dataset
for dataset in requested_training_datasets
if dataset not in training_datasets
]
if non_existing_datasets:
raise BadUserInput(
Expand All @@ -187,11 +167,21 @@ def _validate_inputdata_training_datasets(
def _validate_inputdata_validation_datasets(
requested_data_model: str,
requested_validation_datasets: List[str],
algorithm_specification_validation_flag,
validation_datasets: List[str],
):
"""
Validates that the validation dataset values exist and that the validation_datasets.
Validates that the validation dataset values exist.
"""
if not algorithm_specification_validation_flag and requested_validation_datasets:
raise BadUserInput(
"The algorithm does not have a validation flow, but 'validation_datasets' were provided in the 'inputdata'."
)
elif algorithm_specification_validation_flag and not requested_validation_datasets:
raise BadUserInput(
"The algorithm requires 'validation_datasets', in the 'inputdata', but none were provided."
)

if not requested_validation_datasets:
return

Expand Down
8 changes: 4 additions & 4 deletions exareme2/controller/services/api/algorithm_spec_dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _get_data_model_input_data_specification_dto():
)


def _get_valiadtion_datasets_input_data_specification_dto():
def _get_validation_datasets_input_data_specification_dto():
return InputDataSpecificationDTO(
label="Set of data to validate.",
desc="The set of data to validate the algorithm model on.",
Expand Down Expand Up @@ -163,15 +163,15 @@ def _convert_inputdata_specifications_to_dto(spec: InputDataSpecifications):
# These parameters are not added by the algorithm developer.
y = _convert_inputdata_specification_to_dto(spec.y)
x = _convert_inputdata_specification_to_dto(spec.x) if spec.x else None
validation_datasets = (
_get_valiadtion_datasets_input_data_specification_dto()
validation_datasets_dto = (
_get_validation_datasets_input_data_specification_dto()
if spec.validation
else None
)
return InputDataSpecificationsDTO(
y=y,
x=x,
validation_datasets=validation_datasets,
validation_datasets=validation_datasets_dto,
data_model=_get_data_model_input_data_specification_dto(),
datasets=_get_datasets_input_data_specification_dto(),
filter=_get_filters_input_data_specification_dto(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def get_cdes_per_data_model(self) -> DataModelsCDES:
def get_datasets_locations(self) -> DatasetsLocations:
return self._registries.data_model_registry.datasets_locations

def get_train_and_validation_datasets(
def get_training_and_validation_datasets(
self, data_model: str
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -527,8 +527,8 @@ def get_train_and_validation_datasets(
Returns:
Tuple[List[str], List[str]]: A tuple containing two lists:
- The first list contains training datasets.
- The second list contains validation datasets.
- The first list contains training datasets (Data loaded inside the localworkers).
- The second list contains validation datasets (Data loaded inside the globalworker).
"""
training_datasets = []
validation_datasets = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ def get_parametrization_list_exception_cases():
),
(
BadUserInput,
"Validation is false, but validation datasets were provided.",
"The algorithm does not have a validation flow, but 'validation_datasets' were provided in the 'inputdata'.",
),
id="Validation datasets on algorithm without validation",
),
Expand All @@ -1428,7 +1428,7 @@ def get_parametrization_list_exception_cases():
),
(
BadUserInput,
"Validation is true, but no validation datasets were provided.",
"The algorithm requires 'validation_datasets', in the 'inputdata', but none were provided.",
),
id="Missing validation datasets on algorithm validation",
),
Expand Down

0 comments on commit 878262f

Please sign in to comment.