Skip to content

Commit

Permalink
Added new cerberus validator, changed instrument_type and configurati…
Browse files Browse the repository at this point in the history
…on_type validation_schemas to be applied at the Configuration level instead of Instrument Config, and changed how those validation_schemas errors are reported so they will mimic a DRF validation error on the bad field
  • Loading branch information
Jon committed Sep 28, 2024
1 parent e0dcc01 commit 790010e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 19 deletions.
13 changes: 13 additions & 0 deletions observation_portal/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
from functools import wraps
from django.core.serializers.json import DjangoJSONEncoder
from django.core.cache import caches
from cerberus import Validator


class OCSValidator(Validator):
""" Custom validator that allows label, show(in UI), and description fields in the schema """
def _validate_description(self, constraint, field, value):
pass

def _validate_label(self, constraint, field, value):
pass

def _validate_show(self, constraint, field, value):
pass


def get_queryset_field_values(queryset, field):
Expand Down
58 changes: 42 additions & 16 deletions observation_portal/requestgroups/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from observation_portal.requestgroups.target_helpers import TARGET_TYPE_HELPER_MAP
from observation_portal.common.mixins import ExtraParamsFormatter
from observation_portal.common.configdb import configdb, ConfigDB
from observation_portal.common.utils import OCSValidator
from observation_portal.requestgroups.duration_utils import (
get_total_request_duration, get_requestgroup_duration, get_total_duration_dict,
get_instrument_configuration_duration, get_semester_in
Expand All @@ -45,14 +46,14 @@ def __init__(self):
def validate(self, config_dict: dict) -> dict:
pass

def _validate_document(self, document: dict, validation_schema: dict) -> (Validator, dict):
def _validate_document(self, document: dict, validation_schema: dict) -> (OCSValidator, dict):
"""
Perform validation on a document using Cerberus validation schema
:param document: Document to be validated
:param validation_schema: Cerberus validation schema
:return: Tuple of validator and a validated document
"""
validator = Validator(validation_schema)
validator = OCSValidator(validation_schema)
validator.allow_unknown = True
validated_config_dict = validator.validated(document) or document.copy()

Expand All @@ -74,6 +75,30 @@ def _cerberus_validation_error_to_str(self, validation_errors: dict) -> str:
error_str = error_str.rstrip(', ')
return error_str

def _cerberus_to_serializer_validation_error(self, validation_errors: dict) -> dict:
"""
Unpack and format Cerberus validation errors as a dict matching the DRF Serializer Validation error format
:param validation_errors: Errors from the validator (validator.errors)
:return: Dict containing DRF serializer validation error for the cerberus errors
"""
# The two issues we have are with extra_params becoming a list, and instrument_configs not having their index work properly
serializer_errors = {}
if 'extra_params' in validation_errors:
serializer_errors['extra_params'] = validation_errors['extra_params'][0]
if 'instrument_configs' in validation_errors:
instrument_configs_errors = []
last_instrument_config_with_error = max(validation_errors['instrument_configs'][0].keys())
for i in range(0, last_instrument_config_with_error+1):
if i in validation_errors['instrument_configs'][0]:
instrument_config_error = validation_errors['instrument_configs'][0][i][0].copy()
if 'extra_params' in instrument_config_error:
instrument_config_error['extra_params'] = instrument_config_error['extra_params'][0]
instrument_configs_errors.append(instrument_config_error)
else:
instrument_configs_errors.append({})
serializer_errors['instrument_configs'] = instrument_configs_errors
return serializer_errors


class InstrumentTypeValidationHelper(ValidationHelper):
"""Class to validate config based on InstrumentType in ConfigDB"""
Expand All @@ -91,9 +116,7 @@ def validate(self, config_dict: dict) -> dict:
validation_schema = instrument_type_dict.get('validation_schema', {})
validator, validated_config_dict = self._validate_document(config_dict, validation_schema)
if validator.errors:
raise serializers.ValidationError(_(
f'Invalid configuration: {self._cerberus_validation_error_to_str(validator.errors)}'
))
raise serializers.ValidationError(self._cerberus_to_serializer_validation_error(validator.errors))

return validated_config_dict

Expand Down Expand Up @@ -187,13 +210,16 @@ def __init__(self, instrument_type: str, configuration_type: str):
self._configuration_type = configuration_type

def validate(self, config_dict: dict) -> dict:
configuration_type_properties = configdb.get_configuration_types(self._instrument_type)[self._configuration_type]
configuration_types = configdb.get_configuration_types(self._instrument_type)
if self._configuration_type not in configuration_types:
raise serializers.ValidationError(_(
f'configuration type {self._configuration_type} is not valid for instrument type {self._instrument_type}'
))
configuration_type_properties = configuration_types[self._configuration_type]
validation_schema = configuration_type_properties.get('validation_schema', {})
validator, validated_config_dict = self._validate_document(config_dict, validation_schema)
if validator.errors:
raise serializers.ValidationError(_(
f'Invalid configuration: {self._cerberus_validation_error_to_str(validator.errors)}'
))
raise serializers.ValidationError(self._cerberus_to_serializer_validation_error(validator.errors))

return validated_config_dict

Expand Down Expand Up @@ -374,19 +400,19 @@ def validate(self, data):
acquisition_config['mode'] = AcquisitionConfig.OFF
data['acquisition_config'] = acquisition_config

# Validate the instrument_type and configuration_type properties related validation schema at the configuration level
instrument_type_validation_helper = InstrumentTypeValidationHelper(instrument_type)
instrument_config = instrument_type_validation_helper.validate(data)

configuration_type_validation_helper = ConfigurationTypeValidationHelper(instrument_type, data['type'])
instrument_config = configuration_type_validation_helper.validate(data)

available_optical_elements = configdb.get_optical_elements(instrument_type)
for i, instrument_config in enumerate(data['instrument_configs']):
# Validate the named readout mode if set, or set the default readout mode if left blank
readout_mode = instrument_config.get('mode', '')
readout_validation_helper = ModeValidationHelper('readout', instrument_type, modes['readout'])
instrument_config = readout_validation_helper.validate(instrument_config)

instrument_type_validation_helper = InstrumentTypeValidationHelper(instrument_type)
instrument_config = instrument_type_validation_helper.validate(instrument_config)

configuration_type_validation_helper = ConfigurationTypeValidationHelper(instrument_type, data['type'])
instrument_config = configuration_type_validation_helper.validate(instrument_config)

data['instrument_configs'][i] = instrument_config

# Validate the rotator modes
Expand Down
8 changes: 5 additions & 3 deletions observation_portal/requestgroups/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,15 @@ def get(self, request):
for instrument_type in configdb.get_instrument_type_codes(location=location, only_schedulable=only_schedulable):
if not requested_instrument_type or requested_instrument_type.upper() == instrument_type.upper():
ccd_size = configdb.get_ccd_size(instrument_type)
instrument_type_dict = configdb.get_instrument_type_by_code(instrument_type)
info[instrument_type] = {
'type': configdb.get_instrument_type_category(instrument_type),
'type': instrument_type_dict.get('instrument_category', 'None'),
'class': configdb.get_instrument_type_telescope_class(instrument_type),
'name': configdb.get_instrument_type_full_name(instrument_type),
'name': instrument_type_dict.get('name', instrument_type),
'optical_elements': configdb.get_optical_elements(instrument_type),
'modes': configdb.get_modes_by_type(instrument_type),
'default_acceptability_threshold': configdb.get_default_acceptability_threshold(instrument_type),
'validation_schema': instrument_type_dict.get('validation_schema', {}),
'default_acceptability_threshold': instrument_type_dict.get('default_acceptability_threshold'),
'configuration_types': configdb.get_configuration_types(instrument_type),
'default_configuration_type': configdb.get_default_configuration_type(instrument_type),
'camera_type': {
Expand Down

0 comments on commit 790010e

Please sign in to comment.