diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 722c833..aaf0c94 100644 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -7,7 +7,7 @@ from .errors import ValidationError from .collections import ModelCollection - +from . import models # unique marker for "no default value specified". None is not good enough since # it is a completely valid default value. @@ -258,14 +258,14 @@ def _cast_value(self, value): if isinstance(value, self.items_types): return value else: - if len(self.items_types) != 1: - tpl = 'Cannot decide which type to choose from "{types}".' + embed_type = models.JsonmodelMeta.find_type(value, self.items_types) + if not embed_type: raise ValidationError( - tpl.format( + 'Cannot decide which type to choose from "{types}".' + .format( types=', '.join([t.__name__ for t in self.items_types]) - ) - ) - return self.items_types[0](**value) + )) + return embed_type(**value) def _finish_initialization(self, owner): super(ListField, self)._finish_initialization(owner) @@ -330,19 +330,17 @@ def parse_value(self, value): """Parse value to proper model type.""" if not isinstance(value, dict): return value - - embed_type = self._get_embed_type() + embed_type = models.JsonmodelMeta.find_type(value, self.types) + if not embed_type: + if self.nullable: + return None + else: + raise ValidationError( + 'Cannot decide which type to choose from "{types}".'.format( + types=', '.join([t.__name__ for t in self.types]) + )) return embed_type(**value) - def _get_embed_type(self): - if len(self.types) != 1: - raise ValidationError( - 'Cannot decide which type to choose from "{types}".'.format( - types=', '.join([t.__name__ for t in self.types]) - ) - ) - return self.types[0] - def to_struct(self, value): return value.to_struct() diff --git a/jsonmodels/models.py b/jsonmodels/models.py index da4bbac..1797f8e 100644 --- a/jsonmodels/models.py +++ b/jsonmodels/models.py @@ -24,6 +24,22 @@ def validate_fields(attributes): raise ValueError('Name taken', structue_name, name) taken_names.add(structue_name) + @staticmethod + def find_type(value, types): + matching = {} + for key, _ in value.items(): + if key.startswith('__'): + continue + + for typ in types: + matching.setdefault(typ, 0) + if key in dir(typ): + matching[typ] += 1 + + ordered = sorted(matching.items(), key=lambda kv: kv[1], reverse=True) + if not ordered: + return None + return ordered[0][0] class Base(six.with_metaclass(JsonmodelMeta, object)): diff --git a/tests/test_data_initialization.py b/tests/test_data_initialization.py index 7c4ff36..b534a60 100644 --- a/tests/test_data_initialization.py +++ b/tests/test_data_initialization.py @@ -67,7 +67,77 @@ class ParkingPlace(models.Base): assert car.brand == 'awesome brand' -def test_deep_initialization_error_with_multitypes(): +def test_deep_initialization_multiple_1(): + + class Car(models.Base): + + brand = fields.StringField() + + class Bus(models.Base): + + brand = fields.StringField() + seats = fields.IntField() + + class Train(models.Base): + + line = fields.StringField() + seats = fields.IntField() + + class ParkingPlace(models.Base): + + location = fields.StringField() + vehicle = fields.EmbeddedField([Car, Bus, Train]) + + data1 = { + 'location': 'somewhere', + 'vehicle': { + 'brand': 'awesome brand', + 'seats': 100 + } + } + + parking1 = ParkingPlace(**data1) + parking2 = ParkingPlace() + parking2.populate(**data1) + for parking in [parking1, parking2]: + assert parking.location == 'somewhere' + vehicle = parking.vehicle + assert isinstance(vehicle, Bus) + assert vehicle.brand == 'awesome brand' + assert vehicle.seats == 100 + + data2 = { + 'location': 'somewhere', + 'vehicle': { + 'line': 'Uptown', + 'seats': 400 + } + } + + parking1 = ParkingPlace(**data2) + parking2 = ParkingPlace() + parking2.populate(**data2) + for parking in [parking1, parking2]: + assert parking.location == 'somewhere' + vehicle = parking.vehicle + assert isinstance(vehicle, Train) + assert vehicle.line == 'Uptown' + assert vehicle.seats == 400 + + data3 = { + 'location': 'somewhere', + 'vehicle': { + } + } + + with pytest.raises(errors.ValidationError): + ParkingPlace(**data3) + + with pytest.raises(errors.ValidationError): + parking = ParkingPlace() + parking.populate(**data3) + +def test_deep_initialization_multiple_2(): class Viper(models.Base): @@ -89,12 +159,14 @@ class ParkingPlace(models.Base): } } - with pytest.raises(errors.ValidationError): - ParkingPlace(**data) - - place = ParkingPlace() - with pytest.raises(errors.ValidationError): - place.populate(**data) + parking1 = ParkingPlace(**data) + parking2 = ParkingPlace() + parking2.populate(**data) + for parking in [parking1, parking2]: + assert parking.location == 'somewhere' + car = parking.car + assert isinstance(car, Viper) + assert car.brand == 'awesome brand' def test_deep_initialization_with_list(): @@ -141,42 +213,101 @@ class Parking(models.Base): assert 'three' in values -def test_deep_initialization_error_with_list_and_multitypes(): +def test_deep_initialization_with_list_and_multitypes(): - class Viper(models.Base): + class Car(models.Base): brand = fields.StringField() + horsepower = fields.IntField() + owner = fields.StringField() - class Lamborghini(models.Base): + class Scooter(models.Base): brand = fields.StringField() + horsepower = fields.IntField() + speed = fields.IntField() class Parking(models.Base): location = fields.StringField() - cars = fields.ListField([Viper, Lamborghini]) + vehicle = fields.ListField([Car, Scooter]) data = { 'location': 'somewhere', - 'cars': [ + 'vehicle': [ { - 'brand': 'one', + 'brand': 'viper', + 'horsepower': 987, + 'owner': 'Jeff' }, { - 'brand': 'two', + 'brand': 'lamborgini', + 'horsepower': 877, }, { - 'brand': 'three', + 'brand': 'piaggio', + 'horsepower': 25, + 'speed': 120 }, ], } - with pytest.raises(errors.ValidationError): - Parking(**data) + parking1 = Parking(**data) + parking2 = Parking() + parking2.populate(**data) + for parking in [parking1, parking2]: + assert parking.location == 'somewhere' + vehicles = parking.vehicle + assert isinstance(vehicles, list) + assert len(vehicles) == 3 - parking = Parking() - with pytest.raises(errors.ValidationError): - parking.populate(**data) + assert isinstance(vehicles[0], Car) + assert vehicles[0].brand == 'viper' + assert vehicles[0].horsepower == 987 + assert vehicles[0].owner == 'Jeff' + + assert isinstance(vehicles[1], Car) + assert vehicles[1].brand == 'lamborgini' + assert vehicles[1].horsepower == 877 + assert vehicles[1].owner == None + + assert isinstance(vehicles[2], Scooter) + assert vehicles[2].brand == 'piaggio' + assert vehicles[2].horsepower == 25 + + +def test_deep_initialization_with_empty_list_and_multitypes(): + + class Car(models.Base): + + brand = fields.StringField() + horsepower = fields.IntField() + owner = fields.StringField() + + class Scooter(models.Base): + + brand = fields.StringField() + horsepower = fields.IntField() + speed = fields.IntField() + + class Parking(models.Base): + + location = fields.StringField() + vehicle = fields.ListField([Car, Scooter]) + + data = { + 'location': 'somewhere', + 'vehicle': [] + } + + parking1 = Parking(**data) + parking2 = Parking() + parking2.populate(**data) + for parking in [parking1, parking2]: + assert parking.location == 'somewhere' + vehicles = parking.vehicle + assert isinstance(vehicles, list) + assert len(vehicles) == 0 def test_deep_initialization_error_when_result_non_iterable():