Split datasets to train datasets and validation datasets. #490

Merged
merged 1 commit into from Jul 22, 2024
8 changes: 8 additions & 0 deletions .github/workflows/algorithm_validation_tests.yml
@@ -107,6 +107,9 @@ jobs:
- name: Controller logs
run: cat /tmp/exareme2/controller.out

- name: Globalworker logs
run: cat /tmp/exareme2/globalworker.out

- name: Localworker logs
run: cat /tmp/exareme2/localworker1.out

@@ -115,6 +118,11 @@ jobs:
with:
run: cat /tmp/exareme2/controller.out

- name: Globalworker logs (post run)
uses: webiny/[email protected]
with:
run: cat /tmp/exareme2/globalworker.out

- name: Localworker logs (post run)
uses: webiny/[email protected]
with:
3 changes: 2 additions & 1 deletion exareme2/algorithms/flower/inputdata_preprocessing.py
@@ -24,6 +24,7 @@
class Inputdata(BaseModel):
data_model: str
datasets: List[str]
validation_datasets: List[str]
filters: Optional[dict]
y: Optional[List[str]]
x: Optional[List[str]]
@@ -32,7 +33,7 @@ class Inputdata(BaseModel):
def apply_inputdata(df: pd.DataFrame, inputdata: Inputdata) -> pd.DataFrame:
if inputdata.filters:
df = apply_filter(df, inputdata.filters)
df = df[df["dataset"].isin(inputdata.datasets)]
df = df[df["dataset"].isin(inputdata.datasets + inputdata.validation_datasets)]
columns = inputdata.x + inputdata.y
df = df[columns]
df = df.dropna(subset=columns)
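
A minimal standalone sketch (toy data, not part of the diff) of what the widened filter in apply_inputdata now does: rows belonging to either the training or the validation datasets survive, and everything else is dropped.

import pandas as pd

# Toy frame standing in for the data loaded on a worker.
df = pd.DataFrame(
    {
        "dataset": ["ppmi0", "ppmi1", "ppmi_test", "other"],
        "age": [70, 65, 72, 80],
    }
)

training_datasets = ["ppmi0", "ppmi1"]
validation_datasets = ["ppmi_test"]

# Same expression as the updated line: the union of both dataset lists is kept.
filtered = df[df["dataset"].isin(training_datasets + validation_datasets)]
print(filtered["dataset"].tolist())  # ['ppmi0', 'ppmi1', 'ppmi_test']
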
3 changes: 2 additions & 1 deletion exareme2/algorithms/flower/logistic_regression.json
@@ -32,6 +32,7 @@
],
"notblank": true,
"multiple": true
}
},
"validation": true
}
}
1 change: 1 addition & 0 deletions exareme2/algorithms/specifications.py
@@ -102,6 +102,7 @@ class InputDataSpecification(ImmutableBaseModel):
class InputDataSpecifications(ImmutableBaseModel):
y: InputDataSpecification
x: Optional[InputDataSpecification]
validation: Optional[bool]


class ParameterEnumSpecification(ImmutableBaseModel):
1 change: 1 addition & 0 deletions exareme2/controller/services/api/algorithm_request_dtos.py
@@ -24,6 +24,7 @@ class Config:
class AlgorithmInputDataDTO(ImmutableBaseModel):
data_model: str
datasets: List[str]
validation_datasets: Optional[List[str]]
filters: Optional[dict]
y: Optional[List[str]]
x: Optional[List[str]]
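
For illustration, the inputdata fragment of an algorithm request carrying the new optional field might look like the sketch below (data model, dataset, and variable names are hypothetical; the integration tests at the end of this diff use the same shape).

# Hypothetical request fragment; keys mirror AlgorithmInputDataDTO.
inputdata = {
    "data_model": "dementia:0.1",
    "datasets": ["ppmi0", "ppmi1"],        # training datasets
    "validation_datasets": ["ppmi_test"],  # optional; only for algorithms with "validation": true
    "filters": None,
    "y": ["gender"],
    "x": ["age_value"],
}
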
93 changes: 69 additions & 24 deletions exareme2/controller/services/api/algorithm_request_validator.py
@@ -51,24 +51,21 @@ def validate_algorithm_request(
algorithm_name, algorithm_request_dto.type, algorithms_specs
)

available_datasets_per_data_model = (
worker_landscape_aggregator.get_all_available_datasets_per_data_model()
)

_validate_data_model(
requested_data_model=algorithm_request_dto.inputdata.data_model,
available_datasets_per_data_model=available_datasets_per_data_model,
(
training_datasets,
validation_datasets,
) = worker_landscape_aggregator.get_train_and_validation_datasets(
algorithm_request_dto.inputdata.data_model
)

data_model_cdes = worker_landscape_aggregator.get_cdes(
algorithm_request_dto.inputdata.data_model
)

_validate_algorithm_request_body(
algorithm_request_dto=algorithm_request_dto,
algorithm_specs=algorithm_specs,
transformers_specs=transformers_specs,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
validation_datasets=validation_datasets,
data_model_cdes=data_model_cdes,
smpc_enabled=smpc_enabled,
smpc_optional=smpc_optional,
@@ -89,15 +86,21 @@ def _validate_algorithm_request_body(
algorithm_request_dto: AlgorithmRequestDTO,
algorithm_specs: AlgorithmSpecification,
transformers_specs: Dict[str, TransformerSpecification],
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
validation_datasets: List[str],
data_model_cdes: Dict[str, CommonDataElement],
smpc_enabled: bool,
smpc_optional: bool,
):
_ensure_validation_criteria(
algorithm_request_dto.inputdata.validation_datasets,
algorithm_specs.inputdata.validation,
)
_validate_inputdata(
inputdata=algorithm_request_dto.inputdata,
inputdata_specs=algorithm_specs.inputdata,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
validation_datasets=validation_datasets,
data_model_cdes=data_model_cdes,
)

@@ -122,45 +125,87 @@ def _validate_algorithm_request_body(
)


def _validate_data_model(requested_data_model: str, available_datasets_per_data_model):
if requested_data_model not in available_datasets_per_data_model.keys():
raise BadUserInput(f"Data model '{requested_data_model}' does not exist.")
def _ensure_validation_criteria(validation_datasets: List[str], validation: bool):
[Review comment, Member] In this context we don't know what validation is. Rename to algorithm_specification_validation_flag?

"""
Validates the input based on the provided validation flag and datasets.

Parameters:
validation_datasets (List[str]): List of validation datasets.
validation (bool): Flag indicating if validation is required.

Raises:
BadUserInput: If the input conditions are not met.
"""
if not validation and validation_datasets:
raise BadUserInput(
"Validation is false, but validation datasets were provided."
)
elif validation and not validation_datasets:
raise BadUserInput(
"Validation is true, but no validation datasets were provided."
)


def _validate_inputdata(
inputdata: AlgorithmInputDataDTO,
inputdata_specs: InputDataSpecifications,
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
validation_datasets: List[str],
data_model_cdes: Dict[str, CommonDataElement],
):
_validate_inputdata_dataset(
_validate_inputdata_training_datasets(
requested_data_model=inputdata.data_model,
requested_datasets=inputdata.datasets,
available_datasets_per_data_model=available_datasets_per_data_model,
training_datasets=training_datasets,
)
_validate_inputdata_validation_datasets(
requested_data_model=inputdata.data_model,
requested_validation_datasets=inputdata.validation_datasets,
validation_datasets=validation_datasets,
)
_validate_inputdata_filter(inputdata.data_model, inputdata.filters, data_model_cdes)
_validate_algorithm_inputdatas(inputdata, inputdata_specs, data_model_cdes)


def _validate_inputdata_dataset(
def _validate_inputdata_training_datasets(
requested_data_model: str,
requested_datasets: List[str],
available_datasets_per_data_model: Dict[str, List[str]],
training_datasets: List[str],
):
"""
Validates that the dataset values exist and that the datasets belong in the data_model.
Validates that the requested training datasets exist.
"""
non_existing_datasets = [
dataset
for dataset in requested_datasets
if dataset not in available_datasets_per_data_model[requested_data_model]
dataset for dataset in requested_datasets if dataset not in training_datasets
]
if non_existing_datasets:
raise BadUserInput(
f"Datasets:'{non_existing_datasets}' could not be found for data_model:{requested_data_model}"
)


def _validate_inputdata_validation_datasets(
requested_data_model: str,
requested_validation_datasets: List[str],
validation_datasets: List[str],
):
"""
Validates that the requested validation datasets exist.
"""
if not requested_validation_datasets:
return

non_existing_datasets = [
dataset
for dataset in requested_validation_datasets
if dataset not in validation_datasets
]
if non_existing_datasets:
raise BadUserInput(
f"Validation Datasets:'{non_existing_datasets}' could not be found for data_model:{requested_data_model}"
)


def _validate_inputdata_filter(data_model, filter, data_model_cdes):
"""
Validates that the provided filter has the correct format
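
A condensed, standalone sketch of the rule the new _ensure_validation_criteria check enforces (simplified names; the real code raises exareme2's BadUserInput and runs inside _validate_algorithm_request_body).

from typing import List

class BadUserInput(Exception):
    pass  # stand-in for exareme2.worker_communication.BadUserInput

def ensure_validation_criteria(validation_datasets: List[str], validation_flag: bool) -> None:
    # The algorithm spec must opt in ("validation": true) before validation
    # datasets may be supplied, and must receive them when it does opt in.
    if not validation_flag and validation_datasets:
        raise BadUserInput("Validation is false, but validation datasets were provided.")
    if validation_flag and not validation_datasets:
        raise BadUserInput("Validation is true, but no validation datasets were provided.")

# A request against a non-validating algorithm must not carry validation datasets:
try:
    ensure_validation_criteria(["ppmi_test"], validation_flag=False)
except BadUserInput as exc:
    print(exc)  # Validation is false, but validation datasets were provided.
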
19 changes: 19 additions & 0 deletions exareme2/controller/services/api/algorithm_spec_dtos.py
@@ -48,6 +48,7 @@ class InputDataSpecificationsDTO(ImmutableBaseModel):
filter: InputDataSpecificationDTO
y: InputDataSpecificationDTO
x: Optional[InputDataSpecificationDTO]
validation_datasets: Optional[InputDataSpecificationDTO]


class ParameterEnumSpecificationDTO(ImmutableBaseModel):
@@ -121,6 +122,18 @@ def _get_data_model_input_data_specification_dto():
)


def _get_valiadtion_datasets_input_data_specification_dto():
return InputDataSpecificationDTO(
label="Set of data to validate.",
desc="The set of data to validate the algorithm model on.",
types=[InputDataType.TEXT],
notblank=True,
multiple=True,
stattypes=None,
enumslen=None,
)


def _get_datasets_input_data_specification_dto():
return InputDataSpecificationDTO(
label="Set of data to use.",
@@ -150,9 +163,15 @@ def _convert_inputdata_specifications_to_dto(spec: InputDataSpecifications):
# These parameters are not added by the algorithm developer.
y = _convert_inputdata_specification_to_dto(spec.y)
x = _convert_inputdata_specification_to_dto(spec.x) if spec.x else None
validation_datasets = (
_get_valiadtion_datasets_input_data_specification_dto()
if spec.validation
else None
)
return InputDataSpecificationsDTO(
y=y,
x=x,
validation_datasets=validation_datasets,
data_model=_get_data_model_input_data_specification_dto(),
datasets=_get_datasets_input_data_specification_dto(),
filter=_get_filters_input_data_specification_dto(),
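
For reference, when an algorithm specification sets "validation": true, the specifications endpoint now exposes an extra inputdata entry. Rendered as JSON it would look roughly like the sketch below (approximated from _get_valiadtion_datasets_input_data_specification_dto, not captured output; the serialized value of InputDataType.TEXT is assumed to be "text").

# Approximate shape of the new "validation_datasets" inputdata specification.
validation_datasets_spec = {
    "label": "Set of data to validate.",
    "desc": "The set of data to validate the algorithm model on.",
    "types": ["text"],  # assumed serialization of InputDataType.TEXT
    "notblank": True,
    "multiple": True,
    "stattypes": None,
    "enumslen": None,
}
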
16 changes: 11 additions & 5 deletions exareme2/controller/services/flower/controller.py
@@ -1,5 +1,4 @@
import asyncio
import warnings
from typing import Dict
from typing import List

@@ -51,16 +50,21 @@ def _create_worker_tasks_handler(self, request_id, worker_info: WorkerInfo):
)

async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
async with self.lock:
async with (self.lock):
request_id = algorithm_request_dto.request_id
context_id = UIDGenerator().get_a_uid()
logger = ctrl_logger.get_request_logger(request_id)
datasets = algorithm_request_dto.inputdata.datasets + (
algorithm_request_dto.inputdata.validation_datasets
if algorithm_request_dto.inputdata.validation_datasets
else []
)
csv_paths_per_worker_id: Dict[
str, List[str]
] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
algorithm_request_dto.inputdata.data_model,
algorithm_request_dto.inputdata.datasets,
algorithm_request_dto.inputdata.data_model, datasets
)

workers_info = [
self.worker_landscape_aggregator.get_worker_info(worker_id)
for worker_id in csv_paths_per_worker_id
@@ -93,7 +97,9 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
algorithm_name,
len(task_handlers),
str(server_address),
csv_paths_per_worker_id[server_id],
csv_paths_per_worker_id[server_id]
if algorithm_request_dto.inputdata.validation_datasets
else [],
)
clients_pids = {
handler.start_flower_client(
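
A small standalone sketch of the None-safe concatenation the controller now performs before asking the worker landscape aggregator for CSV paths (the helper name is illustrative, not from the diff).

from typing import List, Optional

def collect_requested_datasets(
    datasets: List[str], validation_datasets: Optional[List[str]]
) -> List[str]:
    # validation_datasets is optional on the request DTO, so guard against None
    # before concatenating, as exec_algorithm does.
    return datasets + (validation_datasets if validation_datasets else [])

print(collect_requested_datasets(["ppmi0", "ppmi1"], ["ppmi_test"]))  # ['ppmi0', 'ppmi1', 'ppmi_test']
print(collect_requested_datasets(["ppmi0", "ppmi1"], None))           # ['ppmi0', 'ppmi1']
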
@@ -26,6 +26,7 @@
)
from exareme2.controller.workers_addresses import WorkersAddressesFactory
from exareme2.utils import AttrDict
from exareme2.worker_communication import BadUserInput
from exareme2.worker_communication import CommonDataElement
from exareme2.worker_communication import CommonDataElements
from exareme2.worker_communication import DataModelAttributes
@@ -145,6 +146,7 @@ def get_csv_paths_per_worker_id(
].items()
if dataset in datasets
]

for dataset_info in dataset_infos:
if not dataset_info.csv_path:
raise DatasetMissingCsvPathError()
@@ -514,6 +516,37 @@ def get_cdes_per_data_model(self) -> DataModelsCDES:
def get_datasets_locations(self) -> DatasetsLocations:
return self._registries.data_model_registry.datasets_locations

def get_train_and_validation_datasets(
self, data_model: str
) -> Tuple[List[str], List[str]]:
"""
Retrieves all available training and validation datasets for a specific data model.

Parameters:
data_model (str): The data model for which to retrieve datasets.

Returns:
Tuple[List[str], List[str]]: A tuple containing two lists:
- The first list contains training datasets.
- The second list contains validation datasets.
"""
training_datasets = []
validation_datasets = []

if data_model not in self.get_datasets_locations().datasets_locations.keys():
raise BadUserInput(f"Data model '{data_model}' does not exist.")
datasets_locations = self.get_datasets_locations().datasets_locations[
data_model
]

for dataset, dataset_location in datasets_locations.items():
if dataset_location.worker_id == self.get_global_worker().id:
validation_datasets.append(dataset)
else:
training_datasets.append(dataset)

return training_datasets, validation_datasets

def get_all_available_datasets_per_data_model(self) -> Dict[str, List[str]]:
return (
self._registries.data_model_registry.get_all_available_datasets_per_data_model()
@@ -815,8 +848,8 @@ def _remove_incompatible_data_models_from_data_models_metadata_per_worker(
data_models_metadata_per_worker: DataModelsMetadataPerWorker
Returns
----------
List[str]
The incompatible data models
DataModelsMetadataPerWorker
The data_models_metadata_per_worker with the incompatible data models removed
"""
validation_dictionary = {}

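
A standalone sketch of the split rule introduced by get_train_and_validation_datasets: datasets hosted on the global worker are treated as validation data, while datasets on local workers remain training data (toy data structures, illustrative only).

from typing import Dict, List, Tuple

def split_datasets(
    datasets_locations: Dict[str, str],  # dataset name -> id of the worker hosting it
    global_worker_id: str,
) -> Tuple[List[str], List[str]]:
    training, validation = [], []
    for dataset, worker_id in datasets_locations.items():
        (validation if worker_id == global_worker_id else training).append(dataset)
    return training, validation

locations = {"ppmi0": "localworker1", "ppmi1": "localworker2", "ppmi_test": "globalworker"}
print(split_datasets(locations, "globalworker"))  # (['ppmi0', 'ppmi1'], ['ppmi_test'])
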
@@ -15,8 +15,8 @@ def test_logistic_regression(get_algorithm_result):
"ppmi7",
"ppmi8",
"ppmi9",
"ppmi_test",
],
"validation_datasets": ["ppmi_test"],
"filters": None,
},
"parameters": None,
@@ -44,8 +44,8 @@ def test_logistic_regression_with_filters(get_algorithm_result):
"ppmi7",
"ppmi8",
"ppmi9",
"ppmi_test",
],
"validation_datasets": ["ppmi_test"],
"filters": {
"condition": "AND",
"rules": [