From 719bf11b14a80ad035670e7b73368433196635d2 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 20 Apr 2015 12:17:35 +0300 Subject: [PATCH 01/39] Update iterables on regular update --- nefertari_sqla/documents.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index b009875..49be1f9 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -345,12 +345,21 @@ def get_or_create(cls, **params): def _update(self, params, **kw): process_bools(params) self.check_fields_allowed(params.keys()) + fields = {c.name: c for c in class_mapper(self.__class__).columns} + iter_fields = set( + k for k, v in fields.items() + if isinstance(v, (DictField, ListField))) id_field = self.id_field() + for key, value in params.items(): # Can't change PK field if key == id_field: continue - setattr(self, key, value) + if key in iter_fields: + self.update_iterables(value, key, unique=True, save=False) + else: + setattr(self, key, value) + session = object_session(self) session.add(self) session.flush() @@ -409,7 +418,8 @@ def to_dict(self, **kwargs): _dict['id'] = getattr(self, self.id_field()) return _dict - def update_iterables(self, params, attr, unique=False, value_type=None): + def update_iterables(self, params, attr, unique=False, + value_type=None, save=True): mapper = class_mapper(self.__class__) fields = {c.name: c for c in mapper.columns} is_dict = isinstance(fields.get(attr), DictField) @@ -439,12 +449,18 @@ def update_dict(): # Set positive keys for key in positive: final_value[unicode(key)] = params[key] - self.update({attr: final_value}) + + setattr(self, attr, final_value) + if save: + session = object_session(self) + session.add(self) + session.flush() def update_list(): final_value = getattr(self, attr, []) or [] final_value = copy.deepcopy(final_value) - positive, negative = split_keys(params.keys()) + keys = params.keys() if isinstance(params, dict) else params + positive, negative = split_keys(keys) if not (positive + negative): raise JHTTPBadRequest('Missing params') @@ -457,7 +473,11 @@ def update_list(): if negative: final_value = list(set(final_value) - set(negative)) - self.update({attr: final_value}) + setattr(self, attr, final_value) + if save: + session = object_session(self) + session.add(self) + session.flush() if is_dict: update_dict() From 7fa596a1eb70737507b5e4ce588bb28673da519f Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 20 Apr 2015 16:18:58 +0300 Subject: [PATCH 02/39] Implement document `autogenerate_for` --- nefertari_sqla/documents.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index b009875..59d14fb 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -64,6 +64,15 @@ class BaseMixin(object): _type = property(lambda self: self.__class__.__name__) + @classmethod + def autogenerate_for(cls, model, set_to): + from sqlalchemy import event + + def generate(mapper, connection, target): + cls(**{set_to: target}) + + event.listen(model, 'after_insert', generate) + @classmethod def id_field(cls): """ Get a primary key field name. 
""" From ac48f45544617eda95653b6cdbbc734708a9ca60 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 20 Apr 2015 16:51:57 +0300 Subject: [PATCH 03/39] Add docstrings for autogenerate_for --- nefertari_sqla/documents.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 59d14fb..c4bcb11 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -66,6 +66,12 @@ class BaseMixin(object): @classmethod def autogenerate_for(cls, model, set_to): + """ Setup `after_insert` event handler. + + Event handler is registered for class :model: and creates a new + instance of :cls: with a field :set_to: set to an instance on + which event occured. + """ from sqlalchemy import event def generate(mapper, connection, target): From 73dccfd76df40ca1b41c1900706bfaf3f9886832 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Fri, 24 Apr 2015 14:13:03 +0300 Subject: [PATCH 04/39] Rename privacy model fields --- nefertari_sqla/documents.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index b009875..d71d826 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -48,8 +48,8 @@ class BaseMixin(object): Attributes: _auth_fields: String names of fields meant to be displayed to authenticated users. - _public_fields: String names of fields meant to be displayed to - NOT authenticated users. + _hidden_fields: String names of fields meant to be displayed to + admin only _nested_fields: ? _nested_relationships: String names of relationship fields that should be included in JSON data of an object as full @@ -57,8 +57,8 @@ class BaseMixin(object): present in this list, this field's value in JSON will be an object's ID or list of IDs. 
""" + _hidden_fields = None _auth_fields = None - _public_fields = None _nested_fields = None _nested_relationships = () From 5d6283239be5eaed67cf0a3db6103c5ffb99c778 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Fri, 24 Apr 2015 16:43:42 +0300 Subject: [PATCH 05/39] Use default values for Dict, List field --- nefertari_sqla/fields.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index 0a88087..342bb31 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -246,6 +246,12 @@ class DictField(BaseField): _sqla_generic_type = ProcessableDict _type_unchanged_kwargs = () + def process_type_args(self, kwargs): + type_args, type_kw, cleaned_kw = super( + DictField, self).process_type_args(kwargs) + cleaned_kw['default'] = cleaned_kw.get('default') or {} + return type_args, type_kw, cleaned_kw + class ListField(BaseField): _sqla_generic_type = ProcessableChoiceArray @@ -276,6 +282,8 @@ def process_type_args(self, kwargs): type_kw['item_type'] = item_type_field._sqla_generic_type + cleaned_kw['default'] = cleaned_kw.get('default') or [] + return type_args, type_kw, cleaned_kw From ab584c42139cac8b32564c2c12a903c7b2f4b953 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Fri, 24 Apr 2015 17:21:41 +0300 Subject: [PATCH 06/39] Serialize timedelta as seconds --- nefertari_sqla/serializers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nefertari_sqla/serializers.py b/nefertari_sqla/serializers.py index 2d6f756..d6421a1 100644 --- a/nefertari_sqla/serializers.py +++ b/nefertari_sqla/serializers.py @@ -15,7 +15,8 @@ def default(self, obj): if (isinstance(obj, datetime.date) and not isinstance(obj, datetime.datetime)): return obj.strftime('%Y-%m-%d') - + if isinstance(obj, datetime.timedelta): + return obj.seconds if isinstance(obj, decimal.Decimal): return float(obj) @@ -35,7 +36,7 @@ def default(self, data): if isinstance(data, datetime.time): return data.strftime('%H:%M:%S') if isinstance(data, datetime.timedelta): - return str(data) + return data.seconds if isinstance(data, decimal.Decimal): return float(data) try: From 2050f488e964f6be9d6588b7616babe22e58dfbe Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 27 Apr 2015 11:14:35 +0300 Subject: [PATCH 07/39] Encode date, datetime as ISO. 
Timedelta in seconds

---
 nefertari_sqla/serializers.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/nefertari_sqla/serializers.py b/nefertari_sqla/serializers.py
index d6421a1..acfb7d0 100644
--- a/nefertari_sqla/serializers.py
+++ b/nefertari_sqla/serializers.py
@@ -12,9 +12,10 @@ class JSONEncoder(_JSONEncoder):
     def default(self, obj):
-        if (isinstance(obj, datetime.date) and
-                not isinstance(obj, datetime.datetime)):
-            return obj.strftime('%Y-%m-%d')
+        if isinstance(obj, (datetime.datetime, datetime.date)):
+            return obj.strftime("%Y-%m-%dT%H:%M:%SZ")  # iso
+        if isinstance(obj, datetime.time):
+            return obj.strftime('%H:%M:%S')
         if isinstance(obj, datetime.timedelta):
             return obj.seconds
         if isinstance(obj, decimal.Decimal):
@@ -29,18 +30,17 @@


 class ESJSONSerializer(elasticsearch.serializer.JSONSerializer):
-    def default(self, data):
-        if (isinstance(data, datetime.date) and
-                not isinstance(data, datetime.datetime)):
-            return data.strftime('%Y-%m-%d')
-        if isinstance(data, datetime.time):
-            return data.strftime('%H:%M:%S')
-        if isinstance(data, datetime.timedelta):
-            return data.seconds
-        if isinstance(data, decimal.Decimal):
-            return float(data)
+    def default(self, obj):
+        if isinstance(obj, (datetime.datetime, datetime.date)):
+            return obj.strftime("%Y-%m-%dT%H:%M:%SZ")  # iso
+        if isinstance(obj, datetime.time):
+            return obj.strftime('%H:%M:%S')
+        if isinstance(obj, datetime.timedelta):
+            return obj.seconds
+        if isinstance(obj, decimal.Decimal):
+            return float(obj)
         try:
-            return super(ESJSONSerializer, self).default(data)
+            return super(ESJSONSerializer, self).default(obj)
         except:
             import traceback
             log.error(traceback.format_exc())

From 86a3908feb40e02a71a9c5639372190b9cbcfeb7 Mon Sep 17 00:00:00 2001
From: Artem Kostiuk
Date: Mon, 27 Apr 2015 11:46:18 +0300
Subject: [PATCH 08/39] Do not index on collection update

---
 nefertari_sqla/signals.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py
index 0d2116d..5c5a31a 100644
--- a/nefertari_sqla/signals.py
+++ b/nefertari_sqla/signals.py
@@ -2,6 +2,7 @@

 from sqlalchemy import event
 from sqlalchemy.ext.declarative import DeclarativeMeta
+from sqlalchemy.orm import object_session


 log = logging.getLogger(__name__)

@@ -20,6 +21,11 @@ def on_after_insert(mapper, connection, target):


 def on_after_update(mapper, connection, target):
+    # Do not index on collections update. Use 'ES.index_refs' on
+    # insert & delete instead.
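+    # Session.is_modified() with include_collections=False ignores changes
+    # made only to relationship collections, so a collection-only update
+    # returns early here and is not re-indexed.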
+ session = object_session(target) + if not session.is_modified(target, include_collections=False): + return from nefertari.elasticsearch import ES # Reload `target` to get access to processed fields values model_cls = target.__class__ From 593a4dd2bee1ea1bc7351fb46f455ea1d99c24b3 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 27 Apr 2015 15:24:55 +0300 Subject: [PATCH 09/39] Define ID fields with primary_key=True --- nefertari_sqla/tests/test_documents.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 6afaa5a..c67fe93 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -56,7 +56,7 @@ class TestBaseMixin(object): def test_id_field(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - my_id_field = fields.IdField() + my_id_field = fields.IdField(primary_key=True) my_int_field = fields.IntegerField() memory_db() @@ -65,7 +65,7 @@ class MyModel(docs.BaseDocument): def test_check_fields_allowed_not_existing_field(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() memory_db() @@ -78,7 +78,7 @@ class MyModel(docs.BaseDocument): def test_check_fields_allowed(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() memory_db() try: @@ -102,7 +102,7 @@ def test_filter_fields(self, mock_fields): def test_apply_fields(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() desc = fields.StringField() title = fields.StringField() @@ -117,7 +117,7 @@ class MyModel(docs.BaseDocument): def test_apply_fields_no_only_fields(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() desc = fields.StringField() title = fields.StringField() @@ -132,7 +132,7 @@ class MyModel(docs.BaseDocument): def test_apply_fields_no_exclude_fields(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() desc = fields.StringField() title = fields.StringField() @@ -147,7 +147,7 @@ class MyModel(docs.BaseDocument): def test_apply_fields_no_any_fields(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() memory_db() @@ -159,7 +159,7 @@ class MyModel(docs.BaseDocument): def test_apply_sort_no_sort(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() memory_db() @@ -169,7 +169,7 @@ class MyModel(docs.BaseDocument): def test_apply_sort(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - id = fields.IdField() + id = fields.IdField(primary_key=True) name = fields.StringField() memory_db() From 21e81aa81ed68a1c33489524eff979ce41d28390 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 27 Apr 2015 17:01:40 +0300 Subject: [PATCH 10/39] Add test for BaseMixin.test_filter_objects --- 
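A short usage sketch of the behaviour the new test pins down. The `User` model
here is hypothetical, and the `first` flag plus the stringified-id `in_()`
filter are read off the test's assertions rather than documented API:

    class User(docs.BaseDocument):
        __tablename__ = 'users'
        id = fields.IdField(primary_key=True)

    some_user = User.get(id=4)
    # Narrows the collection queryset to the ids of the given instances;
    # with first=True the first matching instance (or None) is returned.
    match = User.filter_objects([some_user], first=True)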
nefertari_sqla/tests/test_documents.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index c67fe93..0e04ecb 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -187,3 +187,21 @@ def test_count(self): count = docs.BaseMixin.count(query_set) query_set.count.assert_called_once_with() assert count == 12345 + + @patch.object(docs.BaseMixin, 'get_collection') + def test_filter_objects(self, mock_get, memory_db): + queryset = Mock() + mock_get.return_value = queryset + + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + + MyModel.id.in_ = Mock() + MyModel.filter_objects([Mock(id=4)], first=True) + + mock_get.assert_called_once_with(_limit=1, __raise_on_empty=True) + queryset.from_self.assert_called_once_with() + assert queryset.from_self().filter.call_count == 1 + queryset.from_self().filter().first.assert_called_once_with() + MyModel.id.in_.assert_called_once_with(['4']) From a145f5442429b433adf35663f4b4e2cd15f8858a Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 27 Apr 2015 18:18:37 +0300 Subject: [PATCH 11/39] Add more BaseMixin tests --- nefertari_sqla/documents.py | 2 +- nefertari_sqla/tests/test_documents.py | 78 ++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index b009875..8277caa 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -308,7 +308,7 @@ def native_fields(cls): @classmethod def fields_to_query(cls): query_fields = ['id', '_limit', '_page', '_sort', '_fields', '_count', '_start'] - return query_fields + cls.native_fields() + return list(set(query_fields + cls.native_fields())) @classmethod def get_resource(cls, **params): diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 0e04ecb..5c5ca5d 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -196,6 +196,7 @@ def test_filter_objects(self, mock_get, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' id = fields.IdField(primary_key=True) + memory_db() MyModel.id.in_ = Mock() MyModel.filter_objects([Mock(id=4)], first=True) @@ -205,3 +206,80 @@ class MyModel(docs.BaseDocument): assert queryset.from_self().filter.call_count == 1 queryset.from_self().filter().first.assert_called_once_with() MyModel.id.in_.assert_called_once_with(['4']) + + def test_pop_iterables(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + groups = fields.ListField(item_type=fields.StringField) + settings = fields.DictField() + memory_db() + MyModel.groups.contains = Mock() + MyModel.settings.contains = Mock() + MyModel.settings.has_key = Mock() + MyModel.groups.type.is_postgresql = True + MyModel.settings.type.is_postgresql = True + + params = {'settings': 'foo', 'groups': 'bar', 'id': 1} + iterables, params = MyModel._pop_iterables(params) + assert params == {'id': 1} + assert not MyModel.settings.contains.called + MyModel.settings.has_key.assert_called_once_with('foo') + MyModel.groups.contains.assert_called_once_with(['bar']) + + params = {'settings.foo': 'foo2', 'groups': 'bar', 'id': 1} + iterables, params = MyModel._pop_iterables(params) + assert params == {'id': 1} + assert MyModel.settings.has_key.call_count == 1 + 
MyModel.settings.contains.assert_called_once_with({'foo': 'foo2'}) + + @patch.object(docs.BaseMixin, 'native_fields') + def test_has_field(self, mock_fields): + mock_fields.return_value = ['foo', 'bar'] + assert docs.BaseMixin.has_field('foo') + assert not docs.BaseMixin.has_field('bazz') + + @patch.object(docs.BaseMixin, 'get_collection') + def test_get_resource(self, mock_get_coll): + queryset = Mock() + mock_get_coll.return_value = queryset + resource = docs.BaseMixin.get_resource(foo='bar') + mock_get_coll.assert_called_once_with( + __raise_on_empty=True, _limit=1, foo='bar') + mock_get_coll().first.assert_called_once_with() + assert resource == mock_get_coll().first() + + def test_native_fields(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + assert MyModel.native_fields() == [ + 'updated_at', '_version', 'id', 'name'] + + def test_fields_to_query(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + assert MyModel.fields_to_query() == [ + '_count', '_start', 'name', '_sort', 'updated_at', + '_version', '_limit', '_fields', 'id', '_page'] + + @patch.object(docs.BaseMixin, 'get_resource') + def test_get(self, get_res): + docs.BaseMixin.get(foo='bar') + get_res.assert_called_once_with( + __raise_on_empty=False, foo='bar') + + def test_unique_fields(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField(unique=True) + desc = fields.StringField() + memory_db() + assert MyModel().unique_fields() == [ + MyModel.id, MyModel.name] From 3aec85207518eed45c5a437d1be21cb04ad3d242 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 28 Apr 2015 13:21:59 +0300 Subject: [PATCH 12/39] Add tests for all documents methods except `get_collection` --- nefertari_sqla/documents.py | 4 +- nefertari_sqla/tests/test_documents.py | 308 +++++++++++++++++++++++++ 2 files changed, 310 insertions(+), 2 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index aca4265..b99bbc3 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -398,7 +398,7 @@ def _update_many(cls, items, **params): item._update(params) def __repr__(self): - parts = ['%s:' % self.__class__.__name__] + parts = [] if hasattr(self, 'id'): parts.append('id=%s' % self.id) @@ -406,7 +406,7 @@ def __repr__(self): if hasattr(self, '_version'): parts.append('v=%s' % self._version) - return '<%s>' % ', '.join(parts) + return '<{}: {}>'.format(self.__class__.__name__, ', '.join(parts)) @classmethod def get_by_ids(cls, ids, **params): diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 5c5ca5d..932e0ce 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -4,6 +4,9 @@ from nefertari.utils.dictset import dictset from nefertari.json_httpexceptions import ( JHTTPBadRequest, JHTTPNotFound, JHTTPConflict) +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.orm.collections import InstrumentedList +from sqlalchemy.exc import IntegrityError from .. import documents as docs from .. 
import fields @@ -283,3 +286,308 @@ class MyModel(docs.BaseDocument): memory_db() assert MyModel().unique_fields() == [ MyModel.id, MyModel.name] + + @patch.object(docs.BaseMixin, 'get_collection') + def test_get_or_create_existing(self, get_coll, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + + get_coll.return_value = Mock() + one, created = MyModel.get_or_create( + defaults={'foo': 'bar'}, _limit=2, name='q') + get_coll.assert_called_once_with(_limit=2, name='q') + get_coll().one.assert_called_once_with() + assert not created + assert one == get_coll().one() + + @patch.object(docs.BaseMixin, 'get_collection') + def test_get_or_create_existing_multiple(self, get_coll, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + + queryset = Mock() + get_coll.return_value = queryset + queryset.one.side_effect = MultipleResultsFound + with pytest.raises(JHTTPBadRequest) as ex: + one, created = MyModel.get_or_create( + defaults={'foo': 'bar'}, _limit=2, name='q') + assert 'Bad or Insufficient Params' == str(ex.value) + + @patch.object(docs.BaseMixin, 'get_collection') + def test_get_or_create_existing_created(self, get_coll, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + + queryset = Mock() + get_coll.return_value = queryset + queryset.one.side_effect = NoResultFound + one, created = MyModel.get_or_create( + defaults={'id': 7}, _limit=2, name='q') + assert created + assert queryset.session.add.call_count == 1 + assert queryset.session.flush.call_count == 1 + assert one.id == 7 + assert one.name == 'q' + + @patch.object(docs, 'object_session') + def test_underscore_update(self, obj_session, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + settings = fields.DictField() + memory_db() + + myobj = MyModel(id=4, name='foo') + newobj = myobj._update( + {'id': 5, 'name': 'bar', 'settings': {'sett1': 'val1'}}) + obj_session.assert_called_once_with(myobj) + obj_session().add.assert_called_once_with(myobj) + obj_session().flush.assert_called_once_with() + assert newobj.id == 4 + assert newobj.name == 'bar' + assert newobj.settings == {'sett1': 'val1'} + + @patch.object(docs.BaseMixin, 'get') + @patch.object(docs, 'object_session') + def test_underscore_delete(self, obj_session, mock_get): + docs.BaseMixin._delete(foo='bar') + mock_get.assert_called_once_with(foo='bar') + obj_session.assert_called_once_with(mock_get()) + obj_session().delete.assert_called_once_with(mock_get()) + + @patch.object(docs, 'Session') + def test_underscore_delete_many(self, mock_session): + docs.BaseMixin._delete_many(['foo', 'bar']) + mock_session.assert_called_once_with() + mock_session().delete.assert_called_with('bar') + assert mock_session().delete.call_count == 2 + mock_session().flush.assert_called_once_with() + + def test_udnerscore_update_many(self): + item = Mock() + docs.BaseMixin._update_many([item], foo='bar') + item._update.assert_called_once_with({'foo': 'bar'}) + + def test_repr(self): + obj = docs.BaseMixin() + obj.id = 3 + obj._version = 12 + assert str(obj) == '' + + @patch.object(docs.BaseMixin, 'get_collection') + def test_get_by_ids(self, mock_coll, memory_db): + class 
MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + name = fields.IdField(primary_key=True) + memory_db() + MyModel.name = Mock() + MyModel.get_by_ids([1, 2, 3], foo='bar') + mock_coll.assert_called_once_with(foo='bar') + MyModel.name.in_.assert_called_once_with([1, 2, 3]) + assert mock_coll().filter.call_count == 1 + mock_coll().filter().limit.assert_called_once_with(3) + + def test_to_dict(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + _nested_relationships = ['other_obj3'] + id = fields.IdField(primary_key=True) + other_obj = fields.StringField() + other_obj2 = fields.StringField() + other_obj3 = fields.StringField() + memory_db() + myobj1 = MyModel(id=1) + myobj1.other_obj = MyModel(id=2) + myobj1.other_obj2 = InstrumentedList([MyModel(id=3)]) + myobj1.other_obj3 = MyModel(id=4) + + result = myobj1.to_dict() + assert list(sorted(result.keys())) == [ + '_type', '_version', 'id', 'other_obj', 'other_obj2', 'other_obj3', + 'updated_at'] + assert result['_type'] == 'MyModel' + assert result['id'] == 1 + # Not nester one-to-one + assert result['other_obj'] == 2 + # Not nester many-to-one + assert result['other_obj2'] == [3] + # Nested one-to-one + assert isinstance(result['other_obj3'], dict) + assert result['other_obj3']['_type'] == 'MyModel' + assert result['other_obj3']['id'] == 4 + + @patch.object(docs, 'object_session') + def test_update_iterables_dict(self, obj_session, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + settings = fields.DictField() + memory_db() + myobj = MyModel(id=1) + + # No existing value + myobj.update_iterables( + {'setting1': 'foo', 'setting2': 'bar', '__boo': 'boo'}, + attr='settings', save=False) + assert not obj_session.called + assert myobj.settings == {'setting1': 'foo', 'setting2': 'bar'} + + # New values to existing value + myobj.update_iterables( + {'-setting1': 'foo', 'setting3': 'baz'}, attr='settings', + save=False) + assert not obj_session.called + assert myobj.settings == {'setting2': 'bar', 'setting3': 'baz'} + + # With save + myobj.update_iterables({}, attr='settings', save=True) + obj_session.assert_called_once_with(myobj) + obj_session().add.assert_called_once_with(myobj) + obj_session().flush.assert_called_once_with() + + @patch.object(docs, 'object_session') + def test_update_iterables_list(self, obj_session, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + settings = fields.ListField(item_type=fields.StringField) + memory_db() + myobj = MyModel(id=1) + + # No existing value + myobj.update_iterables( + {'setting1': '', 'setting2': '', '__boo': 'boo'}, + attr='settings', save=False) + assert not obj_session.called + assert myobj.settings == ['setting1', 'setting2'] + + # New values to existing value + myobj.update_iterables( + {'-setting1': '', 'setting3': ''}, attr='settings', + unique=True, save=False) + assert not obj_session.called + assert myobj.settings == ['setting2', 'setting3'] + + # With save + myobj.update_iterables( + {'setting2': ''}, attr='settings', unique=False, save=True) + assert myobj.settings == ['setting2', 'setting3', 'setting2'] + obj_session.assert_called_once_with(myobj) + obj_session().add.assert_called_once_with(myobj) + obj_session().flush.assert_called_once_with() + + def test_get_reference_documents(self, memory_db): + class Child(docs.BaseDocument): + __tablename__ = 'child' + id = fields.IdField(primary_key=True) + parent_id = 
fields.ForeignKeyField( + ref_document='Parent', ref_column='parent.id', + ref_column_type=fields.IdField) + class Parent(docs.BaseDocument): + __tablename__ = 'parent' + id = fields.IdField(primary_key=True) + children = fields.Relationship( + document='Child', backref_name='parent') + memory_db() + + parent = Parent(id=1) + child = Child(id=1, parent=parent) + result = [v for v in child.get_reference_documents()] + assert len(result) == 1 + assert result[0][0] is Parent + assert result[0][1] == [parent.to_dict()] + + # 'Many' side of relationship values are not returned + assert child in parent.children + result = [v for v in parent.get_reference_documents()] + assert len(result) == 0 + + +class TestBaseDocument(object): + + def test_bump_version(self, memory_db): + from datetime import datetime + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + memory_db() + + myobj = MyModel(id=None) + assert myobj._version is None + assert myobj.updated_at is None + myobj._bump_version() + assert myobj._version is None + assert myobj.updated_at is None + + myobj.id = 1 + myobj._bump_version() + assert myobj._version == 1 + assert isinstance(myobj.updated_at, datetime) + + @patch.object(docs, 'object_session') + def test_save(self, obj_session, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + memory_db() + + myobj = MyModel(id=4) + newobj = myobj.save() + assert newobj == myobj + assert myobj._version == 1 + obj_session.assert_called_once_with(myobj) + obj_session().add.assert_called_once_with(myobj) + obj_session().flush.assert_called_once_with() + + @patch.object(docs, 'object_session') + def test_save_error(self, obj_session, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + memory_db() + + err = IntegrityError(None, None, None, None) + err.message = 'duplicate' + obj_session().flush.side_effect = err + + with pytest.raises(JHTTPConflict) as ex: + MyModel(id=4).save() + assert 'There was a conflict' in str(ex.value) + + @patch.object(docs.BaseMixin, '_update') + def test_update(self, mock_upd, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + + myobj = MyModel(id=4) + myobj.update({'name': 'q'}) + mock_upd.assert_called_once_with({'name': 'q'}) + + @patch.object(docs.BaseMixin, '_update') + def test_update_error(self, mock_upd, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + memory_db() + + err = IntegrityError(None, None, None, None) + err.message = 'duplicate' + mock_upd.side_effect = err + + with pytest.raises(JHTTPConflict) as ex: + MyModel(id=4).update({'name': 'q'}) + assert 'There was a conflict' in str(ex.value) From c01843f5d89cd74c18e1b33e2fdebd6de00fa25a Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 28 Apr 2015 15:09:07 +0300 Subject: [PATCH 13/39] Add tests for querying with BaseMixin.get_collection --- nefertari_sqla/tests/fixtures.py | 11 ++ nefertari_sqla/tests/test_documents.py | 236 ++++++++++++++----------- 2 files changed, 146 insertions(+), 101 deletions(-) diff --git a/nefertari_sqla/tests/fixtures.py b/nefertari_sqla/tests/fixtures.py index 7d419f4..32ffaee 100644 --- a/nefertari_sqla/tests/fixtures.py +++ b/nefertari_sqla/tests/fixtures.py @@ -25,6 +25,17 @@ def clear(): 
return creator +@pytest.fixture +def simple_model(request): + from .. import fields, documents as docs + + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + return MyModel + + # Not used yet, because memory_db is called for each test @pytest.fixture def db_session(request): diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 932e0ce..ebbcd3a 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -10,7 +10,7 @@ from .. import documents as docs from .. import fields -from .fixtures import memory_db, db_session +from .fixtures import memory_db, db_session, simple_model class TestDocumentHelpers(object): @@ -65,27 +65,20 @@ class MyModel(docs.BaseDocument): assert MyModel.id_field() == 'my_id_field' - def test_check_fields_allowed_not_existing_field(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_check_fields_allowed_not_existing_field( + self, simple_model, memory_db): memory_db() with pytest.raises(JHTTPBadRequest) as ex: - MyModel.check_fields_allowed(('id__in', 'name', 'description')) + simple_model.check_fields_allowed(('id__in', 'name', 'description')) assert "'MyModel' object does not have fields" in str(ex.value) assert 'description' in str(ex.value) assert 'name' not in str(ex.value) - def test_check_fields_allowed(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_check_fields_allowed(self, simple_model, memory_db): memory_db() try: - MyModel.check_fields_allowed(('id__in', 'name')) + simple_model.check_fields_allowed(('id__in', 'name')) except JHTTPBadRequest: raise Exception('Unexpected JHTTPBadRequest exception raised') @@ -147,42 +140,30 @@ class MyModel(docs.BaseDocument): query_set.with_entities.assert_called_once_with( MyModel.title) - def test_apply_fields_no_any_fields(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_apply_fields_no_any_fields(self, simple_model, memory_db): memory_db() query_set = Mock() _fields = [] - MyModel.apply_fields(query_set, _fields) + simple_model.apply_fields(query_set, _fields) assert not query_set.with_entities.called - def test_apply_sort_no_sort(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_apply_sort_no_sort(self, simple_model, memory_db): memory_db() queryset = ['a', 'b'] - assert MyModel.apply_sort(queryset, []) == queryset + assert simple_model.apply_sort(queryset, []) == queryset - def test_apply_sort(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_apply_sort(self, simple_model, memory_db): memory_db() - MyModel.name.desc = Mock() + simple_model.name.desc = Mock() queryset = Mock() _sort = ['id', '-name'] - MyModel.apply_sort(queryset, _sort) - MyModel.name.desc.assert_called_once_with() + simple_model.apply_sort(queryset, _sort) + simple_model.name.desc.assert_called_once_with() queryset.order_by.assert_called_once_with( - MyModel.id, MyModel.name.desc()) + simple_model.id, simple_model.name.desc()) 
def test_count(self): query_set = Mock() @@ -192,23 +173,19 @@ def test_count(self): assert count == 12345 @patch.object(docs.BaseMixin, 'get_collection') - def test_filter_objects(self, mock_get, memory_db): - queryset = Mock() - mock_get.return_value = queryset - - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) + def test_filter_objects(self, mock_get, simple_model, memory_db): memory_db() - MyModel.id.in_ = Mock() - MyModel.filter_objects([Mock(id=4)], first=True) + queryset = Mock() + mock_get.return_value = queryset + simple_model.id.in_ = Mock() + simple_model.filter_objects([Mock(id=4)], first=True) mock_get.assert_called_once_with(_limit=1, __raise_on_empty=True) queryset.from_self.assert_called_once_with() assert queryset.from_self().filter.call_count == 1 queryset.from_self().filter().first.assert_called_once_with() - MyModel.id.in_.assert_called_once_with(['4']) + simple_model.id.in_.assert_called_once_with(['4']) def test_pop_iterables(self, memory_db): class MyModel(docs.BaseDocument): @@ -252,22 +229,14 @@ def test_get_resource(self, mock_get_coll): mock_get_coll().first.assert_called_once_with() assert resource == mock_get_coll().first() - def test_native_fields(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_native_fields(self, simple_model, memory_db): memory_db() - assert MyModel.native_fields() == [ + assert simple_model.native_fields() == [ 'updated_at', '_version', 'id', 'name'] - def test_fields_to_query(self, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_fields_to_query(self, simple_model, memory_db): memory_db() - assert MyModel.fields_to_query() == [ + assert simple_model.fields_to_query() == [ '_count', '_start', 'name', '_sort', 'updated_at', '_version', '_limit', '_fields', 'id', '_page'] @@ -288,15 +257,11 @@ class MyModel(docs.BaseDocument): MyModel.id, MyModel.name] @patch.object(docs.BaseMixin, 'get_collection') - def test_get_or_create_existing(self, get_coll, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_get_or_create_existing(self, get_coll, simple_model, memory_db): memory_db() get_coll.return_value = Mock() - one, created = MyModel.get_or_create( + one, created = simple_model.get_or_create( defaults={'foo': 'bar'}, _limit=2, name='q') get_coll.assert_called_once_with(_limit=2, name='q') get_coll().one.assert_called_once_with() @@ -304,33 +269,27 @@ class MyModel(docs.BaseDocument): assert one == get_coll().one() @patch.object(docs.BaseMixin, 'get_collection') - def test_get_or_create_existing_multiple(self, get_coll, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_get_or_create_existing_multiple( + self, get_coll, simple_model, memory_db): memory_db() queryset = Mock() get_coll.return_value = queryset queryset.one.side_effect = MultipleResultsFound with pytest.raises(JHTTPBadRequest) as ex: - one, created = MyModel.get_or_create( + one, created = simple_model.get_or_create( defaults={'foo': 'bar'}, _limit=2, name='q') assert 'Bad or Insufficient Params' == str(ex.value) @patch.object(docs.BaseMixin, 'get_collection') - def 
test_get_or_create_existing_created(self, get_coll, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_get_or_create_existing_created( + self, get_coll, simple_model, memory_db): memory_db() queryset = Mock() get_coll.return_value = queryset queryset.one.side_effect = NoResultFound - one, created = MyModel.get_or_create( + one, created = simple_model.get_or_create( defaults={'id': 7}, _limit=2, name='q') assert created assert queryset.session.add.call_count == 1 @@ -515,14 +474,11 @@ class Parent(docs.BaseDocument): class TestBaseDocument(object): - def test_bump_version(self, memory_db): + def test_bump_version(self, simple_model, memory_db): from datetime import datetime - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) memory_db() - myobj = MyModel(id=None) + myobj = simple_model(id=None) assert myobj._version is None assert myobj.updated_at is None myobj._bump_version() @@ -535,13 +491,10 @@ class MyModel(docs.BaseDocument): assert isinstance(myobj.updated_at, datetime) @patch.object(docs, 'object_session') - def test_save(self, obj_session, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) + def test_save(self, obj_session, simple_model, memory_db): memory_db() - myobj = MyModel(id=4) + myobj = simple_model(id=4) newobj = myobj.save() assert newobj == myobj assert myobj._version == 1 @@ -550,10 +503,7 @@ class MyModel(docs.BaseDocument): obj_session().flush.assert_called_once_with() @patch.object(docs, 'object_session') - def test_save_error(self, obj_session, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) + def test_save_error(self, obj_session, simple_model, memory_db): memory_db() err = IntegrityError(None, None, None, None) @@ -561,27 +511,19 @@ class MyModel(docs.BaseDocument): obj_session().flush.side_effect = err with pytest.raises(JHTTPConflict) as ex: - MyModel(id=4).save() + simple_model(id=4).save() assert 'There was a conflict' in str(ex.value) @patch.object(docs.BaseMixin, '_update') - def test_update(self, mock_upd, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_update(self, mock_upd, simple_model, memory_db): memory_db() - myobj = MyModel(id=4) + myobj = simple_model(id=4) myobj.update({'name': 'q'}) mock_upd.assert_called_once_with({'name': 'q'}) @patch.object(docs.BaseMixin, '_update') - def test_update_error(self, mock_upd, memory_db): - class MyModel(docs.BaseDocument): - __tablename__ = 'mymodel' - id = fields.IdField(primary_key=True) - name = fields.StringField() + def test_update_error(self, mock_upd, simple_model, memory_db): memory_db() err = IntegrityError(None, None, None, None) @@ -589,5 +531,97 @@ class MyModel(docs.BaseDocument): mock_upd.side_effect = err with pytest.raises(JHTTPConflict) as ex: - MyModel(id=4).update({'name': 'q'}) + simple_model(id=4).update({'name': 'q'}) assert 'There was a conflict' in str(ex.value) + + +class TestGetCollection(object): + + def test_sort_param(self, simple_model, memory_db): + memory_db() + + simple_model(id=1, name='foo').save() + simple_model(id=2, name='bar').save() + + result = simple_model.get_collection(_limit=2, _sort=['-id']) + assert result[0].id == 2 + assert result[1].id == 1 + + result = 
simple_model.get_collection(_limit=2, _sort=['id']) + assert result[1].id == 2 + assert result[0].id == 1 + + def test_limit_param(self, simple_model, memory_db): + memory_db() + + simple_model(id=1, name='foo').save() + simple_model(id=2, name='bar').save() + + result = simple_model.get_collection(_limit=1, _sort=['id']) + assert result.count() == 1 + assert result[0].id == 1 + + def test_fields_param(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + result = simple_model.get_collection(_limit=1, _fields=['name']) + assert result.all() == [(u'foo',)] + + def test_offset(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + simple_model(id=2, name='bar').save() + + result = simple_model.get_collection(_limit=2, _sort=['id'], _start=1) + assert result.count() == 1 + assert result[0].id == 2 + + result = simple_model.get_collection(_limit=1, _sort=['id'], _page=1) + assert result.count() == 1 + assert result[0].id == 2 + + def test_count_param(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + result = simple_model.get_collection(_limit=2, _count=True) + assert result == 1 + + def test_explain_param(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + result = simple_model.get_collection(_limit=2, _explain=True) + assert result.startswith('SELECT mymodel') + + def test_strict_param(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + with pytest.raises(JHTTPBadRequest): + simple_model.get_collection( + _limit=2, __strict=True, name='foo', qwe=1) + + result = simple_model.get_collection( + _limit=2, __strict=False, name='foo', qwe=1) + assert result.all()[0].name == 'foo' + + def test_raise_on_empty_param(self, simple_model, memory_db): + memory_db() + with pytest.raises(JHTTPNotFound): + simple_model.get_collection(_limit=1, __raise_on_empty=True) + + try: + simple_model.get_collection(_limit=1, __raise_on_empty=False) + except JHTTPNotFound: + raise Exception('Unexpected JHTTPNotFound exception') + + def test_queryset_metadata(self, simple_model, memory_db): + memory_db() + simple_model(id=1, name='foo').save() + queryset = simple_model.get_collection(_limit=1) + assert queryset._nefertari_meta['total'] == 1 + assert queryset._nefertari_meta['start'] == 0 + assert queryset._nefertari_meta['fields'] == [] + + def test_no_limit(self, simple_model, memory_db): + memory_db() + with pytest.raises(JHTTPBadRequest): + simple_model.get_collection() From 6b13359e1a7f0891878343b46daa6c42a1d57fc5 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 28 Apr 2015 17:38:47 +0300 Subject: [PATCH 14/39] Add tests for types. Dont check agains choices if choices is None --- nefertari_sqla/tests/test_types.py | 320 +++++++++++++++++++++++++++++ nefertari_sqla/types.py | 27 +-- 2 files changed, 335 insertions(+), 12 deletions(-) create mode 100644 nefertari_sqla/tests/test_types.py diff --git a/nefertari_sqla/tests/test_types.py b/nefertari_sqla/tests/test_types.py new file mode 100644 index 0000000..3ae5216 --- /dev/null +++ b/nefertari_sqla/tests/test_types.py @@ -0,0 +1,320 @@ +import datetime + +import pytest +from mock import patch, Mock +from sqlalchemy.dialects.postgresql import ARRAY, HSTORE + +from .. import documents as docs +from .. import fields +from .. 
import types +from .fixtures import memory_db, db_session, simple_model + + +class DemoClass(object): + def __init__(self, *args, **kwargs): + pass + + +class TestProcessableMixin(object): + + class Processable(types.ProcessableMixin, DemoClass): + pass + + def test_process_bind_param(self): + processors = [ + lambda v: v.lower(), + lambda v: v.strip(), + lambda v: 'Processed ' + v, + ] + mixin = self.Processable(processors=processors) + value = mixin.process_bind_param(' WeIrd ValUE ', None) + assert value == 'Processed weird value' + + def test_process_bind_param_no_processors(self): + mixin = self.Processable() + value = mixin.process_bind_param(' WeIrd ValUE ', None) + assert value == ' WeIrd ValUE ' + + +class TestLengthLimitedStringMixin(object): + + class Limited(types.LengthLimitedStringMixin, DemoClass): + pass + + def test_none_value(self): + mixin = self.Limited(min_length=5) + try: + mixin.process_bind_param(None, None) + except ValueError: + raise Exception('Unexpected exception') + + def test_min_length(self): + mixin = self.Limited(min_length=5) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param('q', None) + assert str(ex.value) == 'Value length must be more than 5' + try: + mixin.process_bind_param('asdasdasd', None) + except ValueError: + raise Exception('Unexpected exception') + + def test_max_length(self): + mixin = self.Limited(max_length=5) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param('asdasdasdasdasd', None) + assert str(ex.value) == 'Value length must be less than 5' + try: + mixin.process_bind_param('q', None) + except ValueError: + raise Exception('Unexpected exception') + + def test_min_and_max_length(self): + mixin = self.Limited(max_length=5, min_length=2) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param('a', None) + assert str(ex.value) == 'Value length must be more than 2' + with pytest.raises(ValueError) as ex: + mixin.process_bind_param('a12313123123', None) + assert str(ex.value) == 'Value length must be less than 5' + try: + mixin.process_bind_param('12q', None) + except ValueError: + raise Exception('Unexpected exception') + + +class TestSizeLimitedNumberMixin(object): + + class Limited(types.SizeLimitedNumberMixin, DemoClass): + pass + + def test_none_value(self): + mixin = self.Limited(min_value=5) + try: + mixin.process_bind_param(None, None) + except ValueError: + raise Exception('Unexpected exception') + + def test_min_value(self): + mixin = self.Limited(min_value=5) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param(1, None) + assert str(ex.value) == 'Value must be bigger than 5' + try: + mixin.process_bind_param(10, None) + except ValueError: + raise Exception('Unexpected exception') + + def test_max_value(self): + mixin = self.Limited(max_value=5) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param(10, None) + assert str(ex.value) == 'Value must be less than 5' + try: + mixin.process_bind_param(1, None) + except ValueError: + raise Exception('Unexpected exception') + + def test_min_and_max_value(self): + mixin = self.Limited(max_value=5, min_value=2) + with pytest.raises(ValueError) as ex: + mixin.process_bind_param(1, None) + assert str(ex.value) == 'Value must be bigger than 2' + with pytest.raises(ValueError) as ex: + mixin.process_bind_param(10, None) + assert str(ex.value) == 'Value must be less than 5' + try: + mixin.process_bind_param(3, None) + except ValueError: + raise Exception('Unexpected exception') + + +class TestProcessableChoice(object): + + def 
test_no_choices(self): + field = types.ProcessableChoice() + with pytest.raises(ValueError) as ex: + field.process_bind_param('foo', None) + assert str(ex.value) == 'Got an invalid choice `foo`. Valid choices: ()' + + def test_none_value(self): + field = types.ProcessableChoice() + try: + field.process_bind_param(None, None) + except ValueError: + raise Exception('Unexpected error') + + def test_value_not_in_choices(self): + field = types.ProcessableChoice(choices=['foo']) + with pytest.raises(ValueError) as ex: + field.process_bind_param('bar', None) + assert str(ex.value) == 'Got an invalid choice `bar`. Valid choices: (foo)' + + def test_value_in_choices(self): + field = types.ProcessableChoice(choices=['foo']) + try: + field.process_bind_param('foo', None) + except ValueError: + raise Exception('Unexpected error') + + def test_processed_value_in_choices(self): + field = types.ProcessableChoice( + choices=['foo'], + processors=[lambda v: v.lower()]) + try: + field.process_bind_param('FoO', None) + except ValueError: + raise Exception('Unexpected error') + + def test_choices_not_sequence(self): + field = types.ProcessableChoice(choices='foo') + try: + field.process_bind_param('foo', None) + except ValueError: + raise Exception('Unexpected error') + + +class TestProcessableInterval(object): + + def test_passing_seconds(self): + field = types.ProcessableInterval() + value = field.process_bind_param(36000, None) + assert isinstance(value, datetime.timedelta) + assert value.seconds == 36000 + + def test_passing_timedelta(self): + field = types.ProcessableInterval() + value = field.process_bind_param(datetime.timedelta(seconds=60), None) + assert isinstance(value, datetime.timedelta) + + +class TestProcessableDict(object): + + def test_load_dialect_impl_postgresql(self): + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'postgresql' + field.load_dialect_impl(dialect=dialect) + assert field.is_postgresql + dialect.type_descriptor.assert_called_once_with(HSTORE) + + def test_load_dialect_impl_not_postgresql(self): + from sqlalchemy.types import UnicodeText + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'some_other' + field.load_dialect_impl(dialect=dialect) + assert not field.is_postgresql + dialect.type_descriptor.assert_called_once_with(UnicodeText) + + def test_process_bind_param_postgres(self): + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'postgresql' + assert {'q': 'f'} == field.process_bind_param({'q': 'f'}, dialect) + + def test_process_bind_param_not_postgres(self): + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'some_other' + assert '{"q": "f"}' == field.process_bind_param({'q': 'f'}, dialect) + + def test_process_result_value_postgres(self): + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'postgresql' + assert {'q': 'f'} == field.process_result_value({'q': 'f'}, dialect) + + def test_process_result_value_not_postgres(self): + field = types.ProcessableDict() + dialect = Mock() + dialect.name = 'some_other' + assert {'q': 'f'} == field.process_result_value('{"q": "f"}', dialect) + + +class TestProcessableChoiceArray(object): + + @patch.object(types, 'ARRAY') + @patch.object(types.types, 'UnicodeText') + def test_load_dialect_impl_postgresql(self, mock_unic, mock_array): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'postgresql' + field.load_dialect_impl(dialect=dialect) + assert field.is_postgresql + assert not 
mock_unic.called + assert mock_array.called + + @patch.object(types, 'ARRAY') + @patch.object(types.types, 'UnicodeText') + def test_load_dialect_impl_not_postgresql(self, mock_unic, mock_array): + from sqlalchemy.types import UnicodeText + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'some_other' + field.load_dialect_impl(dialect=dialect) + assert not field.is_postgresql + assert mock_unic.called + assert not mock_array.called + + def test_choices_not_sequence(self): + field = types.ProcessableChoiceArray( + item_type=fields.StringField, choices='foo') + assert field.choices == ['foo'] + + def test_validate_choices_no_choices(self): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + assert field.choices is None + try: + field._validate_choices(['foo']) + except ValueError: + raise Exception('Unexpected error') + + def test_validate_choices_no_value(self): + field = types.ProcessableChoiceArray( + item_type=fields.StringField, choices=['foo']) + try: + field._validate_choices(None) + except ValueError: + raise Exception('Unexpected error') + + def test_validate_choices_valid(self): + field = types.ProcessableChoiceArray( + item_type=fields.StringField, + choices=['foo', 'bar']) + try: + field._validate_choices(['foo']) + except ValueError: + raise Exception('Unexpected error') + + def test_validate_choices_invalid(self): + field = types.ProcessableChoiceArray( + item_type=fields.StringField, + choices=['foo', 'bar']) + with pytest.raises(ValueError) as ex: + field._validate_choices(['qoo', 'foo']) + assert str(ex.value) == ( + 'Got invalid choices: (qoo). Valid choices: (foo, bar)') + + def test_process_bind_param_postgres(self): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'postgresql' + assert ['q'] == field.process_bind_param(['q'], dialect) + + def test_process_bind_param_not_postgres(self): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'some_other' + assert '["q"]' == field.process_bind_param(['q'], dialect) + + def test_process_result_value_postgres(self): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'postgresql' + assert ['q'] == field.process_result_value(['q'], dialect) + + def test_process_result_value_not_postgres(self): + field = types.ProcessableChoiceArray(item_type=fields.StringField) + dialect = Mock() + dialect.name = 'some_other' + assert ['q'] == field.process_result_value('["q"]', dialect) diff --git a/nefertari_sqla/types.py b/nefertari_sqla/types.py index 3226379..325e87a 100644 --- a/nefertari_sqla/types.py +++ b/nefertari_sqla/types.py @@ -130,8 +130,8 @@ class ProcessableChoice(ProcessableMixin, types.TypeDecorator): impl = types.String def __init__(self, *args, **kwargs): - self.choices = kwargs.pop('choices', None) - if not isinstance(self.choices, (list, tuple, list)): + self.choices = kwargs.pop('choices', ()) + if not isinstance(self.choices, (list, tuple, set)): self.choices = [self.choices] super(ProcessableChoice, self).__init__(*args, **kwargs) @@ -188,6 +188,7 @@ def load_dialect_impl(self, dialect): self.is_postgresql = True return dialect.type_descriptor(HSTORE) else: + self.is_postgresql = False return dialect.type_descriptor(types.UnicodeText) def process_bind_param(self, value, dialect): @@ -219,10 +220,11 @@ class ProcessableChoiceArray(ProcessableMixin, types.TypeDecorator): impl = ARRAY def 
__init__(self, *args, **kwargs): - self.kwargs = kwargs - self.choices = kwargs.pop('choices', ()) or () - if not isinstance(self.choices, (list, tuple, list)): + self.choices = kwargs.pop('choices', None) + if self.choices is not None and not isinstance( + self.choices, (list, tuple, set)): self.choices = [self.choices] + self.kwargs = kwargs super(ProcessableChoiceArray, self).__init__(*args, **kwargs) def load_dialect_impl(self, dialect): @@ -235,6 +237,7 @@ def load_dialect_impl(self, dialect): self.is_postgresql = True return dialect.type_descriptor(ARRAY(**self.kwargs)) else: + self.is_postgresql = False self.kwargs.pop('item_type', None) return dialect.type_descriptor(types.UnicodeText(**self.kwargs)) @@ -242,14 +245,14 @@ def _validate_choices(self, value): """ Perform :value: validation checking if its items are contained in :self.choices: """ - if not self.choices: + if self.choices is None or value is None: return value - if value is not None: - invalid_choices = set(value) - set(self.choices) - if invalid_choices: - raise ValueError( - 'Got invalid choices: ({}). Valid choices: ({})'.format( - ', '.join(invalid_choices), ', '.join(self.choices))) + + invalid_choices = set(value) - set(self.choices) + if invalid_choices: + raise ValueError( + 'Got invalid choices: ({}). Valid choices: ({})'.format( + ', '.join(invalid_choices), ', '.join(self.choices))) return value def process_bind_param(self, value, dialect): From 4f4f629d9aa1706f5808921f48293ae807eb6d70 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Wed, 29 Apr 2015 10:48:59 +0300 Subject: [PATCH 15/39] Use public_fields instead of hidden_fields --- nefertari_sqla/documents.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 1617cf5..fd83f98 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -48,8 +48,8 @@ class BaseMixin(object): Attributes: _auth_fields: String names of fields meant to be displayed to authenticated users. - _hidden_fields: String names of fields meant to be displayed to - admin only + _public_fields: String names of fields meant to be displayed to + non-authenticated users. _nested_fields: ? _nested_relationships: String names of relationship fields that should be included in JSON data of an object as full @@ -57,7 +57,7 @@ class BaseMixin(object): present in this list, this field's value in JSON will be an object's ID or list of IDs. 
""" - _hidden_fields = None + _public_fields = None _auth_fields = None _nested_fields = None _nested_relationships = () From da826535a21b72c23cf9ca6aab9593067f9c18d9 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Wed, 29 Apr 2015 12:25:21 +0300 Subject: [PATCH 16/39] Replace `id` in JSON with primary key value --- nefertari_sqla/documents.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 1617cf5..ebb6584 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -429,8 +429,7 @@ def to_dict(self, **kwargs): _data[field] = value _dict = DataProxy(_data).to_dict(**kwargs) _dict['_type'] = self._type - if not _dict.get('id'): - _dict['id'] = getattr(self, self.id_field()) + _dict['id'] = getattr(self, self.id_field()) return _dict def update_iterables(self, params, attr, unique=False, From 541b6f38d4740f171564695418d895e332c979f0 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 5 May 2015 12:42:39 +0300 Subject: [PATCH 17/39] Add id_field_type classmethod to Document. Allow passing sqla_generic_type to ForeignKeyField --- nefertari_sqla/documents.py | 4 ++++ nefertari_sqla/fields.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index b528023..6f49f06 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -84,6 +84,10 @@ def id_field(cls): """ Get a primary key field name. """ return class_mapper(cls).primary_key[0].name + @classmethod + def id_field_type(cls): + return class_mapper(cls).primary_key[0].type.__class__ + @classmethod def check_fields_allowed(cls, fields): """ Check if `fields` are allowed to be used on this model. """ diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index 342bb31..61a4bce 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -370,11 +370,13 @@ def __init__(self, *args, **kwargs): Type is determined using 'ref_column_type' value from :kwargs:. Its value must be a *Field class of a field that is being - referenced by FK field. + referenced by FK field or a `_sqla_generic_type` of that *Field cls. """ if not args: field_type = kwargs.pop(self._schema_kwarg_prefix + 'column_type') - self._sqla_generic_type = field_type._sqla_generic_type + if hasattr(field_type, '_sqla_generic_type'): + field_type = field_type._sqla_generic_type + self._sqla_generic_type = field_type super(ForeignKeyField, self).__init__(*args, **kwargs) def _get_referential_action(self, kwargs, key): From 0a3dd14909e29957149702f8144ce80df1286896 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Wed, 6 May 2015 18:24:04 +0300 Subject: [PATCH 18/39] Rename _sqla_generic_type to _sqla_type --- nefertari_sqla/fields.py | 67 ++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index 61a4bce..f690bbf 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -30,17 +30,16 @@ class BaseField(Column): sqlalchemy.Column(sqlalchemy.Type()) Attributes: - _sqla_generic_type: SQLAlchemy generic type class used to instantiate - the column type. + _sqla_type: SQLAlchemy type class used to instantiate the column type. _type_unchanged_kwargs: sequence of strings that represent arguments - received by `_sqla_generic_type` names of which have not been + received by `_sqla_type` names of which have not been changed. 
Values of field init arguments with these names will be extracted from field init kwargs and passed to Type init as is. _column_valid_kwargs: sequence of string names of valid kwargs that Column may receive. """ - _sqla_generic_type = None + _sqla_type = None _type_unchanged_kwargs = () _column_valid_kwargs = ( 'name', 'type_', 'autoincrement', 'default', 'doc', 'key', 'index', @@ -61,7 +60,7 @@ def __init__(self, *args, **kwargs): col_kw['name'], col_kw['type_'] = args # Column init when defining a schema else: - col_kw['type_'] = self._sqla_generic_type(*type_args, **type_kw) + col_kw['type_'] = self._sqla_type(*type_args, **type_kw) return super(BaseField, self).__init__(**col_kw) def process_type_args(self, kwargs): @@ -115,12 +114,12 @@ def _constructor(self): class BigIntegerField(BaseField): - _sqla_generic_type = LimitedBigInteger + _sqla_type = LimitedBigInteger _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') class BooleanField(BaseField): - _sqla_generic_type = ProcessableBoolean + _sqla_type = ProcessableBoolean _type_unchanged_kwargs = ('create_constraint', 'processors') def process_type_args(self, kwargs): @@ -137,31 +136,31 @@ def process_type_args(self, kwargs): class DateField(BaseField): - _sqla_generic_type = ProcessableDate + _sqla_type = ProcessableDate _type_unchanged_kwargs = ('processors',) class DateTimeField(BaseField): - _sqla_generic_type = ProcessableDateTime + _sqla_type = ProcessableDateTime _type_unchanged_kwargs = ('timezone', 'processors') class ChoiceField(BaseField): - _sqla_generic_type = ProcessableChoice + _sqla_type = ProcessableChoice _type_unchanged_kwargs = ( 'collation', 'convert_unicode', 'unicode_error', '_warn_on_bytestring', 'choices', 'processors') class FloatField(BaseField): - _sqla_generic_type = LimitedFloat + _sqla_type = LimitedFloat _type_unchanged_kwargs = ( 'precision', 'asdecimal', 'decimal_return_scale', 'min_value', 'max_value', 'processors') class IntegerField(BaseField): - _sqla_generic_type = LimitedInteger + _sqla_type = LimitedInteger _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') @@ -173,41 +172,41 @@ class IdField(IntegerField): class IntervalField(BaseField): - _sqla_generic_type = ProcessableInterval + _sqla_type = ProcessableInterval _type_unchanged_kwargs = ( 'native', 'second_precision', 'day_precision', 'processors') class BinaryField(BaseField): - _sqla_generic_type = ProcessableLargeBinary + _sqla_type = ProcessableLargeBinary _type_unchanged_kwargs = ('length', 'processors') # Since SQLAlchemy 1.0.0 # class MatchField(BooleanField): -# _sqla_generic_type = MatchType +# _sqla_type = MatchType class DecimalField(BaseField): - _sqla_generic_type = LimitedNumeric + _sqla_type = LimitedNumeric _type_unchanged_kwargs = ( 'precision', 'scale', 'decimal_return_scale', 'asdecimal', 'min_value', 'max_value', 'processors') class PickleField(BaseField): - _sqla_generic_type = ProcessablePickleType + _sqla_type = ProcessablePickleType _type_unchanged_kwargs = ( 'protocol', 'pickler', 'comparator', 'processors') class SmallIntegerField(BaseField): - _sqla_generic_type = LimitedSmallInteger + _sqla_type = LimitedSmallInteger _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') class StringField(BaseField): - _sqla_generic_type = LimitedString + _sqla_type = LimitedString _type_unchanged_kwargs = ( 'collation', 'convert_unicode', 'unicode_error', '_warn_on_bytestring', 'min_length', 'max_length', @@ -227,23 +226,23 @@ def process_type_args(self, kwargs): class 
TextField(StringField): - _sqla_generic_type = LimitedText + _sqla_type = LimitedText class TimeField(DateTimeField): - _sqla_generic_type = ProcessableTime + _sqla_type = ProcessableTime class UnicodeField(StringField): - _sqla_generic_type = LimitedUnicode + _sqla_type = LimitedUnicode class UnicodeTextField(StringField): - _sqla_generic_type = LimitedUnicodeText + _sqla_type = LimitedUnicodeText class DictField(BaseField): - _sqla_generic_type = ProcessableDict + _sqla_type = ProcessableDict _type_unchanged_kwargs = () def process_type_args(self, kwargs): @@ -254,12 +253,12 @@ def process_type_args(self, kwargs): class ListField(BaseField): - _sqla_generic_type = ProcessableChoiceArray + _sqla_type = ProcessableChoiceArray _type_unchanged_kwargs = ( 'as_tuple', 'dimensions', 'zero_indexes', 'choices') def process_type_args(self, kwargs): - """ Covert field class to its `_sqla_generic_type`. + """ Covert field class to its `_sqla_type`. StringField & UnicodeField are replaced with corresponding Text fields because when String* fields are used, SQLA creates @@ -280,7 +279,7 @@ def process_type_args(self, kwargs): if item_type_field is UnicodeField: item_type_field = UnicodeTextField - type_kw['item_type'] = item_type_field._sqla_generic_type + type_kw['item_type'] = item_type_field._sqla_type cleaned_kw['default'] = cleaned_kw.get('default') or [] @@ -325,7 +324,7 @@ def __init__(self, *args, **kwargs): column_kw['name'], column_kw['type_'], schema_item = args # Column init when defining a schema else: - column_kw['type_'] = self._sqla_generic_type(*type_args, **type_kw) + column_kw['type_'] = self._sqla_type(*type_args, **type_kw) column_args = (schema_item,) return Column.__init__(self, *column_args, **column_kw) @@ -357,7 +356,7 @@ class ForeignKeyField(BaseSchemaItemField): model to add/update relationship. Use `Relationship` constructor with backreference settings instead. """ - _sqla_generic_type = None + _sqla_type = None _type_unchanged_kwargs = () _schema_class = ForeignKey _schema_kwarg_prefix = 'ref_' @@ -366,17 +365,17 @@ class ForeignKeyField(BaseSchemaItemField): 'ondelete', 'deferrable', 'initially', 'link_to_name', 'match') def __init__(self, *args, **kwargs): - """ Override to determine `self._sqla_generic_type`. + """ Override to determine `self._sqla_type`. Type is determined using 'ref_column_type' value from :kwargs:. Its value must be a *Field class of a field that is being - referenced by FK field or a `_sqla_generic_type` of that *Field cls. + referenced by FK field or a `_sqla_type` of that *Field cls. 
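As a sketch of the two accepted forms of `ref_column_type`, reusing the Parent/Child pattern from the tests in this series (the models themselves are illustrative):

    from nefertari_sqla import documents as docs, fields

    class Parent(docs.BaseDocument):
        __tablename__ = 'parent'
        id = fields.IdField(primary_key=True)
        children = fields.Relationship(
            document='Child', backref_name='parent')

    class Child(docs.BaseDocument):
        __tablename__ = 'child'
        id = fields.IdField(primary_key=True)
        # `ref_column_type` may be the *Field class itself...
        parent_id = fields.ForeignKeyField(
            ref_document='Parent', ref_column='parent.id',
            ref_column_type=fields.IdField)
        # ...or, per this patch, that field's `_sqla_type`
        # (e.g. ref_column_type=fields.IdField._sqla_type).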
""" if not args: field_type = kwargs.pop(self._schema_kwarg_prefix + 'column_type') - if hasattr(field_type, '_sqla_generic_type'): - field_type = field_type._sqla_generic_type - self._sqla_generic_type = field_type + if hasattr(field_type, '_sqla_type'): + field_type = field_type._sqla_type + self._sqla_type = field_type super(ForeignKeyField, self).__init__(*args, **kwargs) def _get_referential_action(self, kwargs, key): From 9df6d46d1edd30c41f06c5ffa4531dd4332dae41 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 11 May 2015 13:16:17 +0300 Subject: [PATCH 19/39] Fix get_by_ids --- nefertari_sqla/documents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 6f49f06..6d2092b 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -416,7 +416,7 @@ def __repr__(self): def get_by_ids(cls, ids, **params): query_set = cls.get_collection(**params) cls_id = getattr(cls, cls.id_field()) - return query_set.filter(cls_id.in_(ids)).limit(len(ids)) + return query_set.from_self().filter(cls_id.in_(ids)).limit(len(ids)) def to_dict(self, **kwargs): native_fields = self.__class__.native_fields() From c12fb138671362af417bfdecfacd9b3c644da7f6 Mon Sep 17 00:00:00 2001 From: Jonathan Stoikovitch Date: Mon, 11 May 2015 10:39:36 -0400 Subject: [PATCH 20/39] Adding tox and travis --- .travis.yml | 7 +++++++ tox.ini | 15 +++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 .travis.yml create mode 100644 tox.ini diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..7d17c19 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,7 @@ +# Config file for automatic testing at travis-ci.org +language: python +env: + - TOXENV=py27 +install: + - pip install tox +script: tox diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..bae77e4 --- /dev/null +++ b/tox.ini @@ -0,0 +1,15 @@ +[tox] +envlist = py27 + +[testenv] +setenv = + PYTHONHASHSEED=0 +deps = -rrequirements.dev +commands = py.test {posargs:--cov nefertari_sqla} + +[testenv:flake8] +deps = + flake8==2.3.0 + pep8==1.6.2 +commands = + flake8 nefertari_sqla From dca8051b53b38d4eba05dc7a2ceb7e4efc5426a1 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 11 May 2015 17:50:44 +0300 Subject: [PATCH 21/39] Fix filter_objects to query by IDs before other args. Allow passing queryset to get_collection --- nefertari_sqla/documents.py | 26 ++++++++++------- nefertari_sqla/tests/test_documents.py | 40 +++++++++++++++++++------- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 6d2092b..dfd139c 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -158,26 +158,29 @@ def filter_objects(cls, objects, first=False, **params): :object: Sequence of :cls: instances on which query should be run. :params: Query parameters. 
""" - if first: - params['_limit'] = 1 - params['__raise_on_empty'] = True - queryset = cls.get_collection(**params) - id_name = cls.id_field() ids = [getattr(obj, id_name, None) for obj in objects] ids = [str(id_) for id_ in ids if id_ is not None] field_obj = getattr(cls, id_name) - queryset = queryset.from_self().filter(field_obj.in_(ids)) + + session = Session() + query_set = session.query(cls).filter(field_obj.in_(ids)) + + if first: + params['_limit'] = 1 + params['__raise_on_empty'] = True + params['query_set'] = query_set.from_self() + query_set = cls.get_collection(**params) if first: - first_obj = queryset.first() + first_obj = query_set.first() if not first_obj: msg = "'{}({}={})' resource not found".format( cls.__name__, id_name, params[id_name]) raise JHTTPNotFound(msg) return first_obj - return queryset + return query_set @classmethod def _pop_iterables(cls, params): @@ -245,12 +248,14 @@ def get_collection(cls, **params): _limit = params.pop('_limit', None) _page = params.pop('_page', None) _start = params.pop('_start', None) + query_set = params.pop('query_set', None) _count = '_count' in params; params.pop('_count', None) _explain = '_explain' in params; params.pop('_explain', None) __raise_on_empty = params.pop('__raise_on_empty', False) - session = Session() + if query_set is None: + query_set = Session().query(cls) # Remove any __ legacy instructions from this point on params = dictset(filter(lambda item: not item[0].startswith('__'), params.items())) @@ -270,7 +275,8 @@ def get_collection(cls, **params): params.pop_by_values('_all') try: - query_set = session.query(cls).filter_by(**params) + + query_set = query_set.filter_by(**params) # Apply filtering by iterable expressions for expr in iterables_exprs: diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index ebbcd3a..1ee9048 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -172,19 +172,22 @@ def test_count(self): query_set.count.assert_called_once_with() assert count == 12345 + @patch.object(docs, 'Session') @patch.object(docs.BaseMixin, 'get_collection') - def test_filter_objects(self, mock_get, simple_model, memory_db): + def test_filter_objects(self, mock_get, mock_sess, simple_model, memory_db): memory_db() - - queryset = Mock() - mock_get.return_value = queryset + queryset1 = mock_sess().query().filter() + queryset2 = Mock() + mock_get.return_value = queryset2 simple_model.id.in_ = Mock() simple_model.filter_objects([Mock(id=4)], first=True) - mock_get.assert_called_once_with(_limit=1, __raise_on_empty=True) - queryset.from_self.assert_called_once_with() - assert queryset.from_self().filter.call_count == 1 - queryset.from_self().filter().first.assert_called_once_with() + mock_sess().query.assert_called_with(simple_model) + assert mock_sess().query().filter.call_count == 2 + + mock_get.assert_called_once_with( + _limit=1, __raise_on_empty=True, + query_set=queryset1.from_self()) simple_model.id.in_.assert_called_once_with(['4']) def test_pop_iterables(self, memory_db): @@ -353,8 +356,8 @@ class MyModel(docs.BaseDocument): MyModel.get_by_ids([1, 2, 3], foo='bar') mock_coll.assert_called_once_with(foo='bar') MyModel.name.in_.assert_called_once_with([1, 2, 3]) - assert mock_coll().filter.call_count == 1 - mock_coll().filter().limit.assert_called_once_with(3) + assert mock_coll().from_self().filter.call_count == 1 + mock_coll().from_self().filter().limit.assert_called_once_with(3) def test_to_dict(self, memory_db): class 
MyModel(docs.BaseDocument): @@ -537,6 +540,23 @@ def test_update_error(self, mock_upd, simple_model, memory_db): class TestGetCollection(object): + def test_input_queryset(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField() + foo = fields.StringField() + memory_db() + MyModel(id=1, name='foo', foo=2).save() + MyModel(id=2, name='boo', foo=2).save() + MyModel(id=3, name='boo', foo=1).save() + queryset1 = MyModel.get_collection(_limit=50, name='boo') + assert queryset1.count() == 2 + queryset2 = MyModel.get_collection( + _limit=50, foo=2, query_set=queryset1.from_self()) + assert queryset2.count() == 1 + assert queryset2.first().id == 2 + def test_sort_param(self, simple_model, memory_db): memory_db() From 2d10babdbb5ffac887413866a2e4f11f4516f63c Mon Sep 17 00:00:00 2001 From: Jonathan Stoikovitch Date: Mon, 11 May 2015 11:01:16 -0400 Subject: [PATCH 22/39] adding badges --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 93838bc..7053c57 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ # `nefertari-sqla` +[![Build Status](https://travis-ci.org/brandicted/nefertari-sqla.svg?branch=master)](https://travis-ci.org/brandicted/nefertari-sqla) +[![Documentation Status](https://readthedocs.org/projects/nefertari-sqla/badge/?version=stable)](http://nefertari-sqla.readthedocs.org/en/stable/) + SQLA backend for Nefertari From b96c4e2c5ec22f7ffafa007f5f6fd22c72484c4d Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 12 May 2015 08:23:54 +0300 Subject: [PATCH 23/39] Delete nested_fields Document attr --- nefertari_sqla/documents.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index dfd139c..ecff698 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -50,7 +50,6 @@ class BaseMixin(object): authenticated users. _public_fields: String names of fields meant to be displayed to non-authenticated users. - _nested_fields: ? _nested_relationships: String names of relationship fields that should be included in JSON data of an object as full included documents. If relationship field is not @@ -59,7 +58,6 @@ class BaseMixin(object): """ _public_fields = None _auth_fields = None - _nested_fields = None _nested_relationships = () _type = property(lambda self: self.__class__.__name__) From 69dbbbfe84d56b61b3cddb2e23016e93bf93c1fc Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 12 May 2015 08:27:17 +0300 Subject: [PATCH 24/39] Rename id_field -> pk_field --- nefertari_sqla/documents.py | 18 +++++++++--------- nefertari_sqla/signals.py | 8 ++++---- nefertari_sqla/tests/test_documents.py | 6 +++--- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index ecff698..4037449 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -78,12 +78,12 @@ def generate(mapper, connection, target): event.listen(model, 'after_insert', generate) @classmethod - def id_field(cls): + def pk_field(cls): """ Get a primary key field name. """ return class_mapper(cls).primary_key[0].name @classmethod - def id_field_type(cls): + def pk_field_type(cls): return class_mapper(cls).primary_key[0].type.__class__ @classmethod @@ -156,7 +156,7 @@ def filter_objects(cls, objects, first=False, **params): :object: Sequence of :cls: instances on which query should be run. :params: Query parameters. 
""" - id_name = cls.id_field() + id_name = cls.pk_field() ids = [getattr(obj, id_name, None) for obj in objects] ids = [str(id_) for id_ in ids if id_ is not None] field_obj = getattr(cls, id_name) @@ -372,11 +372,11 @@ def _update(self, params, **kw): iter_fields = set( k for k, v in fields.items() if isinstance(v, (DictField, ListField))) - id_field = self.id_field() + pk_field = self.pk_field() for key, value in params.items(): # Can't change PK field - if key == id_field: + if key == pk_field: continue if key in iter_fields: self.update_iterables(value, key, unique=True, save=False) @@ -419,7 +419,7 @@ def __repr__(self): @classmethod def get_by_ids(cls, ids, **params): query_set = cls.get_collection(**params) - cls_id = getattr(cls, cls.id_field()) + cls_id = getattr(cls, cls.pk_field()) return query_set.from_self().filter(cls_id.in_(ids)).limit(len(ids)) def to_dict(self, **kwargs): @@ -429,7 +429,7 @@ def to_dict(self, **kwargs): value = getattr(self, field, None) include = field in self._nested_relationships if not include: - get_id = lambda v: getattr(v, v.id_field(), None) + get_id = lambda v: getattr(v, v.pk_field(), None) if isinstance(value, BaseMixin): value = get_id(value) elif isinstance(value, InstrumentedList): @@ -437,7 +437,7 @@ def to_dict(self, **kwargs): _data[field] = value _dict = DataProxy(_data).to_dict(**kwargs) _dict['_type'] = self._type - _dict['id'] = getattr(self, self.id_field()) + _dict['id'] = getattr(self, self.pk_field()) return _dict def update_iterables(self, params, attr, unique=False, @@ -533,7 +533,7 @@ class BaseDocument(BaseObject, BaseMixin): _version = IntegerField(default=0) def _bump_version(self): - if getattr(self, self.id_field(), None): + if getattr(self, self.pk_field(), None): self.updated_at = datetime.utcnow() self._version = (self._version or 0) + 1 diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 5c5a31a..5520f08 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -13,8 +13,8 @@ def on_after_insert(mapper, connection, target): # Reload `target` to get access to back references and processed # fields values model_cls = target.__class__ - id_field = target.id_field() - reloaded = model_cls.get(**{id_field: getattr(target, id_field)}) + pk_field = target.pk_field() + reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) es = ES(model_cls.__name__) es.index(reloaded.to_dict()) es.index_refs(reloaded) @@ -29,8 +29,8 @@ def on_after_update(mapper, connection, target): from nefertari.elasticsearch import ES # Reload `target` to get access to processed fields values model_cls = target.__class__ - id_field = target.id_field() - reloaded = model_cls.get(**{id_field: getattr(target, id_field)}) + pk_field = target.pk_field() + reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) es = ES(reloaded.__class__.__name__) es.index(reloaded.to_dict()) es.index_refs(reloaded) diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 1ee9048..209822e 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -56,14 +56,14 @@ def test_process_bools(self): class TestBaseMixin(object): - def test_id_field(self, memory_db): + def test_pk_field(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - my_id_field = fields.IdField(primary_key=True) + my_pk_field = fields.IdField(primary_key=True) my_int_field = fields.IntegerField() memory_db() - assert MyModel.id_field() == 
'my_id_field' + assert MyModel.pk_field() == 'my_pk_field' def test_check_fields_allowed_not_existing_field( self, simple_model, memory_db): From 968ba34b17874761b3e11ecfb449e93cbb5c32d1 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 12 May 2015 09:38:48 +0300 Subject: [PATCH 25/39] Fix flake8 errors --- docs/source/conf.py | 122 ++++++++++++------------- nefertari_sqla/documents.py | 33 ++++--- nefertari_sqla/tests/test_documents.py | 9 +- nefertari_sqla/tests/test_types.py | 10 +- setup.py | 4 +- 5 files changed, 97 insertions(+), 81 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1b32a6a..ea04850 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,19 +12,17 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os -import shlex # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -42,7 +40,7 @@ source_suffix = '.rst' # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' @@ -70,9 +68,9 @@ # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -80,27 +78,27 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. 
todo_include_todos = False @@ -108,7 +106,8 @@ # -- Options for HTML output ---------------------------------------------- -# on_rtd is whether we are on readthedocs.org, this line of code grabbed from docs.readthedocs.org +# on_rtd is whether we are on readthedocs.org, this line of code grabbed +# from docs.readthedocs.org on_rtd = os.environ.get('READTHEDOCS', None) == 'True' if not on_rtd: # only import and set the theme if we're building docs locally @@ -116,7 +115,8 @@ html_theme = 'sphinx_rtd_theme' html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -# otherwise, readthedocs.org uses their theme by default, so no need to specify it +# otherwise, readthedocs.org uses their theme by default, so no need to +# specify it # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. @@ -125,26 +125,26 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -154,62 +154,62 @@ # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
-#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'Nefertaridoc' @@ -217,46 +217,46 @@ # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', -# Latex figure (float) alignment -#'figure_align': 'htbp', + # Latex figure (float) alignment + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Nefertari.tex', u'Nefertari Documentation', - u'Brandicted', 'manual'), + (master_doc, 'Nefertari.tex', u'Nefertari Documentation', + u'Brandicted', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- @@ -269,7 +269,7 @@ ] # If true, show URL addresses after external links. 
-#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -278,19 +278,19 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Nefertari', u'Nefertari Documentation', - author, 'Nefertari', 'One line description of project.', - 'Miscellaneous'), + (master_doc, 'Nefertari', u'Nefertari Documentation', + author, 'Nefertari', 'One line description of project.', + 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 4037449..49be322 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -122,9 +122,11 @@ def apply_fields(cls, query_set, _fields): fields_exclude = fields_exclude or [] if fields_exclude: # Remove fields_exclude from fields_only - fields_only = [f for f in fields_only if f not in fields_exclude] + fields_only = [ + f for f in fields_only if f not in fields_exclude] if fields_only: - fields_only = [getattr(cls, f) for f in sorted(set(fields_only))] + fields_only = [ + getattr(cls, f) for f in sorted(set(fields_only))] query_set = query_set.with_entities(*fields_only) except InvalidRequestError as e: @@ -248,20 +250,24 @@ def get_collection(cls, **params): _start = params.pop('_start', None) query_set = params.pop('query_set', None) - _count = '_count' in params; params.pop('_count', None) - _explain = '_explain' in params; params.pop('_explain', None) + _count = '_count' in params + params.pop('_count', None) + _explain = '_explain' in params + params.pop('_explain', None) __raise_on_empty = params.pop('__raise_on_empty', False) if query_set is None: query_set = Session().query(cls) # Remove any __ legacy instructions from this point on - params = dictset(filter(lambda item: not item[0].startswith('__'), params.items())) + params = dictset(filter( + lambda item: not item[0].startswith('__'), params.items())) iterables_exprs, params = cls._pop_iterables(params) if __strict: - _check_fields = [f.strip('-+') for f in params.keys() + _fields + _sort] + _check_fields = [ + f.strip('-+') for f in params.keys() + _fields + _sort] cls.check_fields_allowed(_check_fields) else: params = cls.filter_fields(params) @@ -289,7 +295,8 @@ def get_collection(cls, **params): _start, _limit = process_limit(_start, _page, _limit) - # Filtering by fields has to be the first thing to do on the query_set! + # Filtering by fields has to be the first thing to do on + # the query_set! 
query_set = cls.apply_fields(query_set, _fields) query_set = cls.apply_sort(query_set, _sort) query_set = query_set.offset(_start).limit(_limit) @@ -330,7 +337,8 @@ def native_fields(cls): @classmethod def fields_to_query(cls): - query_fields = ['id', '_limit', '_page', '_sort', '_fields', '_count', '_start'] + query_fields = [ + 'id', '_limit', '_page', '_sort', '_fields', '_count', '_start'] return list(set(query_fields + cls.native_fields())) @classmethod @@ -342,7 +350,8 @@ def get_resource(cls, **params): @classmethod def get(cls, **kw): - return cls.get_resource(__raise_on_empty=kw.pop('__raise', False), **kw) + return cls.get_resource( + __raise_on_empty=kw.pop('__raise', False), **kw) def unique_fields(self): native_fields = class_mapper(self.__class__).columns @@ -550,7 +559,8 @@ def save(self, *arg, **kw): raise # Other error, not duplicate raise JHTTPConflict( - detail='Resource `%s` already exists.' % self.__class__.__name__, + detail='Resource `{}` already exists.'.format( + self.__class__.__name__), extra={'data': e}) def update(self, params): @@ -562,7 +572,8 @@ def update(self, params): raise # other error, not duplicate raise JHTTPConflict( - detail='Resource `%s` already exists.' % self.__class__.__name__, + detail='Resource `{}` already exists.'.format( + self.__class__.__name__), extra={'data': e}) diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 209822e..7411423 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -70,7 +70,8 @@ def test_check_fields_allowed_not_existing_field( memory_db() with pytest.raises(JHTTPBadRequest) as ex: - simple_model.check_fields_allowed(('id__in', 'name', 'description')) + simple_model.check_fields_allowed(( + 'id__in', 'name', 'description')) assert "'MyModel' object does not have fields" in str(ex.value) assert 'description' in str(ex.value) assert 'name' not in str(ex.value) @@ -174,7 +175,8 @@ def test_count(self): @patch.object(docs, 'Session') @patch.object(docs.BaseMixin, 'get_collection') - def test_filter_objects(self, mock_get, mock_sess, simple_model, memory_db): + def test_filter_objects( + self, mock_get, mock_sess, simple_model, memory_db): memory_db() queryset1 = mock_sess().query().filter() queryset2 = Mock() @@ -449,17 +451,20 @@ class MyModel(docs.BaseDocument): obj_session().flush.assert_called_once_with() def test_get_reference_documents(self, memory_db): + class Child(docs.BaseDocument): __tablename__ = 'child' id = fields.IdField(primary_key=True) parent_id = fields.ForeignKeyField( ref_document='Parent', ref_column='parent.id', ref_column_type=fields.IdField) + class Parent(docs.BaseDocument): __tablename__ = 'parent' id = fields.IdField(primary_key=True) children = fields.Relationship( document='Child', backref_name='parent') + memory_db() parent = Parent(id=1) diff --git a/nefertari_sqla/tests/test_types.py b/nefertari_sqla/tests/test_types.py index 3ae5216..eb79310 100644 --- a/nefertari_sqla/tests/test_types.py +++ b/nefertari_sqla/tests/test_types.py @@ -2,9 +2,8 @@ import pytest from mock import patch, Mock -from sqlalchemy.dialects.postgresql import ARRAY, HSTORE +from sqlalchemy.dialects.postgresql import HSTORE -from .. import documents as docs from .. import fields from .. 
import types from .fixtures import memory_db, db_session, simple_model @@ -134,7 +133,8 @@ def test_no_choices(self): field = types.ProcessableChoice() with pytest.raises(ValueError) as ex: field.process_bind_param('foo', None) - assert str(ex.value) == 'Got an invalid choice `foo`. Valid choices: ()' + assert str(ex.value) == \ + 'Got an invalid choice `foo`. Valid choices: ()' def test_none_value(self): field = types.ProcessableChoice() @@ -147,7 +147,8 @@ def test_value_not_in_choices(self): field = types.ProcessableChoice(choices=['foo']) with pytest.raises(ValueError) as ex: field.process_bind_param('bar', None) - assert str(ex.value) == 'Got an invalid choice `bar`. Valid choices: (foo)' + assert str(ex.value) == \ + 'Got an invalid choice `bar`. Valid choices: (foo)' def test_value_in_choices(self): field = types.ProcessableChoice(choices=['foo']) @@ -247,7 +248,6 @@ def test_load_dialect_impl_postgresql(self, mock_unic, mock_array): @patch.object(types, 'ARRAY') @patch.object(types.types, 'UnicodeText') def test_load_dialect_impl_not_postgresql(self, mock_unic, mock_array): - from sqlalchemy.types import UnicodeText field = types.ProcessableChoiceArray(item_type=fields.StringField) dialect = Mock() dialect.name = 'some_other' diff --git a/setup.py b/setup.py index 3dfd598..b9dbbc2 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ -import os - from setuptools import setup, find_packages + install_requires = [ 'sqlalchemy', 'zope.dottedname', @@ -13,6 +12,7 @@ 'nefertari==0.2.1' ] + setup( name='nefertari_sqla', version="0.1.1", From d1ac0ca654b1e96d4cdadbf88770be86c8926e94 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Wed, 13 May 2015 13:21:29 +0300 Subject: [PATCH 26/39] Expire objects on save and in signals --- nefertari_sqla/documents.py | 3 ++- nefertari_sqla/signals.py | 15 ++++++++------- nefertari_sqla/tests/test_documents.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 49be322..f5840dc 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -412,7 +412,7 @@ def _delete_many(cls, items): @classmethod def _update_many(cls, items, **params): for item in items: - item._update(params) + item.update(params) def __repr__(self): parts = [] @@ -553,6 +553,7 @@ def save(self, *arg, **kw): try: session.add(self) session.flush() + session.expire(self) return self except (IntegrityError,) as e: if 'duplicate' not in e.message: diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 5520f08..74b69e7 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -15,6 +15,7 @@ def on_after_insert(mapper, connection, target): model_cls = target.__class__ pk_field = target.pk_field() reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) + es = ES(model_cls.__name__) es.index(reloaded.to_dict()) es.index_refs(reloaded) @@ -26,14 +27,14 @@ def on_after_update(mapper, connection, target): session = object_session(target) if not session.is_modified(target, include_collections=False): return - from nefertari.elasticsearch import ES + # Reload `target` to get access to processed fields values - model_cls = target.__class__ - pk_field = target.pk_field() - reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) - es = ES(reloaded.__class__.__name__) - es.index(reloaded.to_dict()) - es.index_refs(reloaded) + session.expire(target) + + from nefertari.elasticsearch import ES + es = ES(target.__class__.__name__) + 
es.index(target.to_dict()) + es.index_refs(target) def on_after_delete(mapper, connection, target): diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 7411423..e6ec1f6 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -337,10 +337,10 @@ def test_underscore_delete_many(self, mock_session): assert mock_session().delete.call_count == 2 mock_session().flush.assert_called_once_with() - def test_udnerscore_update_many(self): + def test_underscore_update_many(self): item = Mock() docs.BaseMixin._update_many([item], foo='bar') - item._update.assert_called_once_with({'foo': 'bar'}) + item.update.assert_called_once_with({'foo': 'bar'}) def test_repr(self): obj = docs.BaseMixin() From 7277e66a58012c395a3a0659db0a9c80b3f2a097 Mon Sep 17 00:00:00 2001 From: Chris Hart Date: Thu, 14 May 2015 11:11:36 -0400 Subject: [PATCH 27/39] minor spelling/grammar tweaks --- nefertari_sqla/__init__.py | 2 +- nefertari_sqla/documents.py | 18 +++++++++--------- nefertari_sqla/fields.py | 29 +++++++++++++++-------------- nefertari_sqla/types.py | 8 ++++---- nefertari_sqla/utils.py | 6 +++--- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/nefertari_sqla/__init__.py b/nefertari_sqla/__init__.py index d56f6d3..e4f44de 100644 --- a/nefertari_sqla/__init__.py +++ b/nefertari_sqla/__init__.py @@ -47,7 +47,7 @@ def includeme(config): def setup_database(config): - """ Setup db engine, db itself. Create db if not exists. """ + """ Setup db engine, db itself. Create db if it doesn't exist. """ from sqlalchemy import engine_from_config from sqlalchemy_utils import database_exists, create_database from pyramid_sqlalchemy import BaseObject diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index f5840dc..e115cfd 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -68,7 +68,7 @@ def autogenerate_for(cls, model, set_to): Event handler is registered for class :model: and creates a new instance of :cls: with a field :set_to: set to an instance on - which event occured. + which the event occured. """ from sqlalchemy import event @@ -187,10 +187,10 @@ def _pop_iterables(cls, params): """ Pop iterable fields' parameters from :params: and generate SQLA expressions to query the database. - Iterable values are found by checking what keys from :params: + Iterable values are found by checking which keys from :params: correspond to names of Dict/List fields on model. - In case ListField uses `postgresql.ARRAY` type, value is - wrapped in list. + If ListField uses the `postgresql.ARRAY` type, the value is + wrapped in a list. """ from .fields import ListField, DictField iterables = {} @@ -235,9 +235,9 @@ def _pop_iterables(cls, params): @classmethod def get_collection(cls, **params): """ - params may include '_limit', '_page', '_sort', '_fields' - returns paginated and sorted query set - raises JHTTPBadRequest for bad values in params + Params may include '_limit', '_page', '_sort', '_fields'. + Returns paginated and sorted query set. + Raises JHTTPBadRequest for bad values in params. """ log.debug('Get collection: {}, {}'.format(cls.__name__, params)) params.pop('__confirmation', False) @@ -533,7 +533,7 @@ def get_reference_documents(self): class BaseDocument(BaseObject, BaseMixin): """ Base class for SQLA models. - Subclasses of this class that do not define model schema, + Subclasses of this class that do not define a model schema should be abstract as well (__abstract__ = True). 
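For example, an intermediate base class that adds no schema of its own must stay abstract as well (class names are illustrative):

    from nefertari_sqla import documents as docs, fields

    class ProjectBase(docs.BaseDocument):
        # No __tablename__ or columns defined here, so keep it abstract.
        __abstract__ = True

    class Ticket(ProjectBase):
        __tablename__ = 'ticket'
        id = fields.IdField(primary_key=True)
        title = fields.StringField()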
""" __abstract__ = True @@ -581,7 +581,7 @@ def update(self, params): class ESBaseDocument(BaseDocument): """ Base class for SQLA models that use Elasticsearch. - Subclasses of this class that do not define model schema, + Subclasses of this class that do not define a model schema should be abstract as well (__abstract__ = True). """ __abstract__ = True diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index f690bbf..a703f33 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -32,12 +32,12 @@ class BaseField(Column): Attributes: _sqla_type: SQLAlchemy type class used to instantiate the column type. _type_unchanged_kwargs: sequence of strings that represent arguments - received by `_sqla_type` names of which have not been + received by `_sqla_type`, the names of which have not been changed. Values of field init arguments with these names will be extracted from field init kwargs and passed to Type init as is. _column_valid_kwargs: sequence of string names of valid kwargs that - Column may receive. + a Column may receive. """ _sqla_type = None _type_unchanged_kwargs = () @@ -70,7 +70,7 @@ def process_type_args(self, kwargs): Process `kwargs` to extract type-specific arguments. If some arguments' names should be changed, extend this method - with a manual args processing. + with a manual args processor. Returns: * type_args: sequence of type-specific posional arguments @@ -87,8 +87,8 @@ def process_type_args(self, kwargs): def _drop_invalid_kwargs(self, kwargs): """ Drop keys from `kwargs` that are not present in - `self._column_valid_kwargs`, thus are not valid kwargs that - may be passed to Column. + `self._column_valid_kwargs`, thus are not valid kwargs to + be passed to Column. """ return {k: v for k, v in kwargs.items() if k in self._column_valid_kwargs} @@ -290,7 +290,7 @@ class BaseSchemaItemField(BaseField): """ Base class for fields/columns that accept a schema item/constraint on column init. E.g. Column(Integer, ForeignKey('user.id')) - It differs from a regular columns in a way that item/constr passed to + It differs from regular columns in that an item/constraint passed to the Column on init has to be passed as a positional argument and should also receive arguments. Thus 3 objects need to be created on init: Column, Type, and SchemaItem/Constraint. @@ -298,7 +298,7 @@ class BaseSchemaItemField(BaseField): Attributes: _schema_class: Class to be instantiated to create a schema item. _schema_kwarg_prefix: Prefix schema item's kwargs should have. This - is used to not make a mess, as both column, type and schemaitem + is used to avoid making a mess, as both column, type and schemaitem kwargs may be passed at once. _schema_valid_kwargs: Sequence of strings that represent names of kwargs `_schema_class` may receive. Should not include prefix. @@ -346,11 +346,11 @@ class ForeignKeyField(BaseSchemaItemField): """ Integer ForeignKey field. This is the place where `ondelete` rules kwargs should be passed. - If you switched from mongodb engine, copy here the same `ondelete` + If you switched from the mongodb engine, copy the same `ondelete` rules you passed to mongo's `Relationship` constructor. - `ondelete` kwargs may be kept in both fields with no side-effect - when switching between sqla-mongo engines. + `ondelete` kwargs may be kept in both fields with no side-effects + when switching between the sqla and mongo engines. Developers are not encouraged to change the value of this field on model to add/update relationship. 
Use `Relationship` constructor @@ -382,14 +382,15 @@ def _get_referential_action(self, kwargs, key): """ Determine/translate generic rule name to SQLA-specific rule. Output rule name is a valid SQL Referential action name. - If `ondelete` kwarg is not provided, no ref. action will be created. + If `ondelete` kwarg is not provided, no referential action will be + created. Valid kwargs for `ondelete` kwarg are: CASCADE Translates to SQL as `CASCADE` RESTRICT Translates to SQL as `RESTRICT` NULLIFY Translates to SQL as `SET NULL - Not supported SQL ref. actions: `NO ACTION`, `SET DEFAULT` + Not supported SQL referential actions: `NO ACTION`, `SET DEFAULT` """ key = self._schema_kwarg_prefix + key action = kwargs.pop(key, None) @@ -407,7 +408,7 @@ def _get_referential_action(self, kwargs, key): return rules[action] def _generate_schema_item(self, cleaned_kw): - """ Override default implementation to generate 'ondelete', 'onupdate' + """ Override default implementation to generate 'ondelete' and 'onupdate' arguments. """ pref = self._schema_kwarg_prefix @@ -439,7 +440,7 @@ def Relationship(**kwargs): The goal of this wrapper is to allow passing both relationship and backref arguments to a single function. Backref arguments should be prefixed with 'backref_'. - Function splits relationship-specific and backref-specific arguments + This function splits relationship-specific and backref-specific arguments and makes a call like: relationship(..., ..., backref=backref(...)) """ diff --git a/nefertari_sqla/types.py b/nefertari_sqla/types.py index 325e87a..4d99c6c 100644 --- a/nefertari_sqla/types.py +++ b/nefertari_sqla/types.py @@ -63,22 +63,22 @@ def process_bind_param(self, value, dialect): class LimitedString(LengthLimitedStringMixin, types.TypeDecorator): - """ String type, min and max length if which may be limited. """ + """ String type, min and max length limits. """ impl = types.String class LimitedText(LengthLimitedStringMixin, types.TypeDecorator): - """ Text type, min and max length if which may be limited. """ + """ Text type, min and max length limits. """ impl = types.Text class LimitedUnicode(LengthLimitedStringMixin, types.TypeDecorator): - """ Unicode type, min and max length if which may be limited. """ + """ Unicode type, min and max length limits. """ impl = types.Unicode class LimitedUnicodeText(LengthLimitedStringMixin, types.TypeDecorator): - """ UnicideText type, min and max length if which may be limited. """ + """ UnicideText type, min and max length limits. """ impl = types.UnicodeText diff --git a/nefertari_sqla/utils.py b/nefertari_sqla/utils.py index 3dc9517..817beb4 100644 --- a/nefertari_sqla/utils.py +++ b/nefertari_sqla/utils.py @@ -17,11 +17,11 @@ def is_relationship_field(field, model_cls): def relationship_cls(field, model_cls): - """ Return class which is pointed to by relationship field + """ Return class that is pointed to by relationship field `field` from model `model_cls`. - You have to make sure field exists and is a relationship - field by yourself. Use `is_relationship_field` for these purposes. + Make sure field exists and is a relationship + field manually. Use `is_relationship_field` for this. 
""" mapper = class_mapper(model_cls) relationships = {r.key: r for r in mapper.relationships} From 6da7b028ac8613e08f4997f84343273123ca04b2 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Fri, 15 May 2015 10:50:10 +0300 Subject: [PATCH 28/39] Rename relationship_cls -> get_relationship_cls --- nefertari_sqla/__init__.py | 2 +- nefertari_sqla/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nefertari_sqla/__init__.py b/nefertari_sqla/__init__.py index d56f6d3..5f9175e 100644 --- a/nefertari_sqla/__init__.py +++ b/nefertari_sqla/__init__.py @@ -11,7 +11,7 @@ from .signals import ESMetaclass from .utils import ( relationship_fields, is_relationship_field, - relationship_cls) + get_relationship_cls) from .fields import ( BigIntegerField, BooleanField, diff --git a/nefertari_sqla/utils.py b/nefertari_sqla/utils.py index 3dc9517..a03f0c7 100644 --- a/nefertari_sqla/utils.py +++ b/nefertari_sqla/utils.py @@ -16,7 +16,7 @@ def is_relationship_field(field, model_cls): return isinstance(field_obj, relationship_fields) -def relationship_cls(field, model_cls): +def get_relationship_cls(field, model_cls): """ Return class which is pointed to by relationship field `field` from model `model_cls`. From 6dd61f508ca9a88e18226ba30c3c76a82f21e937 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Fri, 15 May 2015 17:06:20 +0300 Subject: [PATCH 29/39] Fix one2one fields not indexed. Fix deletion signal --- nefertari_sqla/signals.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 74b69e7..872b1a5 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -2,7 +2,7 @@ from sqlalchemy import event from sqlalchemy.ext.declarative import DeclarativeMeta -from sqlalchemy.orm import object_session +from sqlalchemy.orm import object_session, class_mapper log = logging.getLogger(__name__) @@ -29,7 +29,8 @@ def on_after_update(mapper, connection, target): return # Reload `target` to get access to processed fields values - session.expire(target) + attributes = [c.name for c in class_mapper(target.__class__).columns] + session.expire(target, attribute_names=attributes) from nefertari.elasticsearch import ES es = ES(target.__class__.__name__) @@ -39,8 +40,10 @@ def on_after_update(mapper, connection, target): def on_after_delete(mapper, connection, target): from nefertari.elasticsearch import ES - es = ES(target.__class__.__name__) - es.delete(target.id) + model_cls = target.__class__ + es = ES(model_cls.__name__) + obj_id = getattr(target, model_cls.pk_field()) + es.delete(obj_id) es.index_refs(target) From 3c80e103c82e79f16ceb7b13faff11fd097cc547 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 11:15:08 +0300 Subject: [PATCH 30/39] Fix duplicate object collection items indexing --- nefertari_sqla/documents.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index e115cfd..a8b100a 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -436,12 +436,15 @@ def to_dict(self, **kwargs): _data = {} for field in native_fields: value = getattr(self, field, None) + is_objects_list = isinstance(value, InstrumentedList) + if is_objects_list: + value = list(set(value)) include = field in self._nested_relationships if not include: get_id = lambda v: getattr(v, v.pk_field(), None) if isinstance(value, BaseMixin): value = get_id(value) - elif isinstance(value, InstrumentedList): + 
elif is_objects_list: value = [get_id(val) for val in value] _data[field] = value _dict = DataProxy(_data).to_dict(**kwargs) From 81c52eb9c610affb5aa8016f999e3c7c0d193b9c Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 14:55:40 +0300 Subject: [PATCH 31/39] Index old parent object when child changes his parent --- nefertari_sqla/documents.py | 15 ++++++++++----- nefertari_sqla/signals.py | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index a8b100a..1f3a427 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -13,7 +13,7 @@ from nefertari.utils import ( process_fields, process_limit, _split, dictset, DataProxy) -from .signals import ESMetaclass +from .signals import ESMetaclass, index_object from .fields import DateTimeField, IntegerField, DictField, ListField log = logging.getLogger(__name__) @@ -383,14 +383,17 @@ def _update(self, params, **kw): if isinstance(v, (DictField, ListField))) pk_field = self.pk_field() - for key, value in params.items(): + for key, new_value in params.items(): # Can't change PK field if key == pk_field: continue + old_value = getattr(self, key, None) if key in iter_fields: - self.update_iterables(value, key, unique=True, save=False) + self.update_iterables(new_value, key, unique=True, save=False) else: - setattr(self, key, value) + setattr(self, key, new_value) + if isinstance(old_value, BaseMixin) and old_value != new_value: + index_object(old_value, with_refs=False) session = object_session(self) session.add(self) @@ -436,7 +439,9 @@ def to_dict(self, **kwargs): _data = {} for field in native_fields: value = getattr(self, field, None) - is_objects_list = isinstance(value, InstrumentedList) + is_objects_list = ( + isinstance(value, (InstrumentedList, list)) and value and + isinstance(value[0], BaseMixin)) if is_objects_list: value = list(set(value)) include = field in self._nested_relationships diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 872b1a5..80f92bb 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -8,17 +8,22 @@ log = logging.getLogger(__name__) -def on_after_insert(mapper, connection, target): +def index_object(obj, with_refs=True): from nefertari.elasticsearch import ES + es = ES(obj.__class__.__name__) + es.index(obj.to_dict()) + if with_refs: + es.index_refs(obj) + + +def on_after_insert(mapper, connection, target): # Reload `target` to get access to back references and processed # fields values model_cls = target.__class__ pk_field = target.pk_field() reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) - es = ES(model_cls.__name__) - es.index(reloaded.to_dict()) - es.index_refs(reloaded) + index_object(reloaded) def on_after_update(mapper, connection, target): @@ -27,15 +32,10 @@ def on_after_update(mapper, connection, target): session = object_session(target) if not session.is_modified(target, include_collections=False): return - # Reload `target` to get access to processed fields values attributes = [c.name for c in class_mapper(target.__class__).columns] session.expire(target, attribute_names=attributes) - - from nefertari.elasticsearch import ES - es = ES(target.__class__.__name__) - es.index(target.to_dict()) - es.index_refs(target) + index_object(target) def on_after_delete(mapper, connection, target): From 58e54893d407f6dbc96696e92a115cf19bdd2103 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 15:10:36 +0300 
Subject: [PATCH 32/39] Refresh objects in get_reference_documents --- nefertari_sqla/documents.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 1f3a427..3a0ef3f 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -439,17 +439,12 @@ def to_dict(self, **kwargs): _data = {} for field in native_fields: value = getattr(self, field, None) - is_objects_list = ( - isinstance(value, (InstrumentedList, list)) and value and - isinstance(value[0], BaseMixin)) - if is_objects_list: - value = list(set(value)) include = field in self._nested_relationships if not include: get_id = lambda v: getattr(v, v.pk_field(), None) if isinstance(value, BaseMixin): value = get_id(value) - elif is_objects_list: + elif isinstance(value, InstrumentedList): value = [get_id(val) for val in value] _data[field] = value _dict = DataProxy(_data).to_dict(**kwargs) @@ -535,6 +530,8 @@ def get_reference_documents(self): # If 'Many' side should be indexed, its value is already a list. if value is None or isinstance(value, list): continue + session = object_session(value) + session.refresh(value) yield (value.__class__, [value.to_dict()]) From df4785671b54278113fca24c4fdd36677bd80ece Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 15:17:17 +0300 Subject: [PATCH 33/39] Fix failing test --- nefertari_sqla/documents.py | 3 +++ nefertari_sqla/tests/test_documents.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 3a0ef3f..0db337a 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -392,6 +392,9 @@ def _update(self, params, **kw): self.update_iterables(new_value, key, unique=True, save=False) else: setattr(self, key, new_value) + + # Trigger reindexation of old value in case it is a DB object and + # it is changed to other object if isinstance(old_value, BaseMixin) and old_value != new_value: index_object(old_value, with_refs=False) diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index e6ec1f6..4f77973 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -450,7 +450,8 @@ class MyModel(docs.BaseDocument): obj_session().add.assert_called_once_with(myobj) obj_session().flush.assert_called_once_with() - def test_get_reference_documents(self, memory_db): + @patch.object(docs, 'object_session') + def test_get_reference_documents(self, mock_sess, memory_db): class Child(docs.BaseDocument): __tablename__ = 'child' @@ -474,6 +475,9 @@ class Parent(docs.BaseDocument): assert result[0][0] is Parent assert result[0][1] == [parent.to_dict()] + mock_sess.assert_called_with(parent) + mock_sess().refresh.assert_called_with(parent) + # 'Many' side of relationship values are not returned assert child in parent.children result = [v for v in parent.get_reference_documents()] From 39403f78821ff9c186c94211c80182b8d0300e56 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 15:26:31 +0300 Subject: [PATCH 34/39] Remove with_refs=False fron document update --- nefertari_sqla/documents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 0db337a..82cf773 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -396,7 +396,7 @@ def _update(self, params, **kw): # Trigger reindexation of old value in case it is 
a DB object and # it is changed to other object if isinstance(old_value, BaseMixin) and old_value != new_value: - index_object(old_value, with_refs=False) + index_object(old_value) session = object_session(self) session.add(self) From b8433b97e4aab1846dd05222d815bc64d2bdb5ae Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Mon, 18 May 2015 19:02:52 +0300 Subject: [PATCH 35/39] Add initial definition of remove event handler --- nefertari_sqla/signals.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 80f92bb..590f247 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -47,10 +47,21 @@ def on_after_delete(mapper, connection, target): es.index_refs(target) +def on_collection_item_remove(target, value, initiator): + index_object(target) + + def setup_es_signals_for(source_cls): event.listen(source_cls, 'after_insert', on_after_insert) event.listen(source_cls, 'after_update', on_after_update) event.listen(source_cls, 'after_delete', on_after_delete) + + # relationships = {r.key: r for r in class_mapper(source_cls).relationships} + # for name, rel in relationships.items(): + # if rel.uselist: + # field_obj = getattr(source_cls, name) + # event.listen(field_obj, 'remove', on_collection_item_remove) + log.info('setup_sqla_es_signals_for: %r' % source_cls) From cdbed30404e217efb3799fcfe416135e19cfdbb1 Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 19 May 2015 12:34:40 +0300 Subject: [PATCH 36/39] Fix update signals handling --- nefertari_sqla/documents.py | 6 ------ nefertari_sqla/fields.py | 9 +++++++++ nefertari_sqla/signals.py | 13 +------------ 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 82cf773..5d02ae4 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -387,17 +387,11 @@ def _update(self, params, **kw): # Can't change PK field if key == pk_field: continue - old_value = getattr(self, key, None) if key in iter_fields: self.update_iterables(new_value, key, unique=True, save=False) else: setattr(self, key, new_value) - # Trigger reindexation of old value in case it is a DB object and - # it is changed to other object - if isinstance(old_value, BaseMixin) and old_value != new_value: - index_object(old_value) - session = object_session(self) session.add(self) session.flush() diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index a703f33..ce9398b 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -443,6 +443,11 @@ def Relationship(**kwargs): This function splits relationship-specific and backref-specific arguments and makes a call like: relationship(..., ..., backref=backref(...)) + + :lazy: setting is set to 'joined' on the 'One' side of One2One or + One2Many relationships. This is done both for relationship itself + and backref so ORM 'after_update' events are fired when relationship + is updated. 
""" backref_pre = 'backref_' kwargs['doc'] = kwargs.pop('help_text', None) @@ -459,7 +464,11 @@ def Relationship(**kwargs): else: rel_kw[key] = val rel_document = rel_kw.pop('document') + if not rel_kw.get('uselist'): + rel_kw['lazy'] = 'joined' if backref_kw: + if not backref_kw.get('uselist'): + backref_kw['lazy'] = 'joined' backref_name = backref_kw.pop('name') rel_kw['backref'] = backref(backref_name, **backref_kw) return relationship(rel_document, **rel_kw) diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index 590f247..ab87ff3 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -22,16 +22,12 @@ def on_after_insert(mapper, connection, target): model_cls = target.__class__ pk_field = target.pk_field() reloaded = model_cls.get(**{pk_field: getattr(target, pk_field)}) - index_object(reloaded) def on_after_update(mapper, connection, target): - # Do not index on collections update. Use 'ES.index_refs' on - # insert & delete instead. session = object_session(target) - if not session.is_modified(target, include_collections=False): - return + # Reload `target` to get access to processed fields values attributes = [c.name for c in class_mapper(target.__class__).columns] session.expire(target, attribute_names=attributes) @@ -55,13 +51,6 @@ def setup_es_signals_for(source_cls): event.listen(source_cls, 'after_insert', on_after_insert) event.listen(source_cls, 'after_update', on_after_update) event.listen(source_cls, 'after_delete', on_after_delete) - - # relationships = {r.key: r for r in class_mapper(source_cls).relationships} - # for name, rel in relationships.items(): - # if rel.uselist: - # field_obj = getattr(source_cls, name) - # event.listen(field_obj, 'remove', on_collection_item_remove) - log.info('setup_sqla_es_signals_for: %r' % source_cls) From b7fc135cbf31022795bced50184f7275d670aace Mon Sep 17 00:00:00 2001 From: Artem Kostiuk Date: Tue, 19 May 2015 12:36:57 +0300 Subject: [PATCH 37/39] Delete not used code --- nefertari_sqla/documents.py | 2 +- nefertari_sqla/signals.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 5d02ae4..571a495 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -13,7 +13,7 @@ from nefertari.utils import ( process_fields, process_limit, _split, dictset, DataProxy) -from .signals import ESMetaclass, index_object +from .signals import ESMetaclass from .fields import DateTimeField, IntegerField, DictField, ListField log = logging.getLogger(__name__) diff --git a/nefertari_sqla/signals.py b/nefertari_sqla/signals.py index ab87ff3..8546089 100644 --- a/nefertari_sqla/signals.py +++ b/nefertari_sqla/signals.py @@ -43,10 +43,6 @@ def on_after_delete(mapper, connection, target): es.index_refs(target) -def on_collection_item_remove(target, value, initiator): - index_object(target) - - def setup_es_signals_for(source_cls): event.listen(source_cls, 'after_insert', on_after_insert) event.listen(source_cls, 'after_update', on_after_update) From 87fe038400c2d1a42ca7386ae5213760598e74b9 Mon Sep 17 00:00:00 2001 From: Chris Hart Date: Tue, 19 May 2015 10:18:35 -0400 Subject: [PATCH 38/39] released 0.2.0 --- docs/source/conf.py | 8 ++- docs/source/index.rst | 139 +++--------------------------------------- setup.py | 4 +- 3 files changed, 16 insertions(+), 135 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index ea04850..9f1d14b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -29,8 +29,12 @@ 
# ones. extensions = [ 'sphinx.ext.autodoc', + 'releases' ] +releases_github_path = 'brandicted/nefertari-sqla' +releases_debug = True + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -55,9 +59,9 @@ # built documents. # # The short X.Y version. -version = '1.0' +# version = '0.1' # The full version, including alpha/beta/rc tags. -release = '1.0' +# release = '1.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index db88180..370e297 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,136 +1,13 @@ SQLA Engine =========== -Documents ---------- +Index +----- -.. autoclass:: nefertari_sqla.documents.BaseMixin - :members: - :special-members: - :private-members: +.. toctree:: + :maxdepth: 2 -.. autoclass:: nefertari_sqla.documents.BaseDocument - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.documents.ESBaseDocument - :members: - :special-members: - :private-members: - -Serializers ------------ - -.. autoclass:: nefertari_sqla.serializers.JSONEncoder - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.serializers.ESJSONSerializer - :members: - :special-members: - :private-members: - -Fields ------- - -.. autoclass:: nefertari_sqla.fields.IntegerField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.BigIntegerField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.SmallIntegerField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.BooleanField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.DateField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.DateTimeField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.FloatField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.StringField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.TextField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.UnicodeField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.UnicodeTextField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.ChoiceField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.BinaryField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.DecimalField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.TimeField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.PickleField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.IntervalField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.IdField - :members: - :special-members: - :private-members: - -.. autoclass:: nefertari_sqla.fields.ForeignKeyField - :members: - :special-members: - :private-members: - -.. 
autoclass:: nefertari_sqla.fields.Relationship - :members: - :special-members: - :private-members: + base_classes + serializers + fields + changelog \ No newline at end of file diff --git a/setup.py b/setup.py index b9dbbc2..c0e3b56 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,13 @@ 'sqlalchemy_utils', 'elasticsearch', 'pyramid_tm', - 'nefertari==0.2.1' + 'nefertari==0.3.0' ] setup( name='nefertari_sqla', - version="0.1.1", + version="0.2.0", description='sqla engine for nefertari', classifiers=[ "Programming Language :: Python", From 01beb7487c0b7c9605d7a9ff7654f706962d330c Mon Sep 17 00:00:00 2001 From: Chris Hart Date: Tue, 19 May 2015 10:18:59 -0400 Subject: [PATCH 39/39] docs refactor --- docs/source/base_classes.rst | 17 ++++++ docs/source/changelog.rst | 10 ++++ docs/source/fields.rst | 102 +++++++++++++++++++++++++++++++++++ docs/source/serializers.rst | 12 +++++ 4 files changed, 141 insertions(+) create mode 100644 docs/source/base_classes.rst create mode 100644 docs/source/changelog.rst create mode 100644 docs/source/fields.rst create mode 100644 docs/source/serializers.rst diff --git a/docs/source/base_classes.rst b/docs/source/base_classes.rst new file mode 100644 index 0000000..85f5b98 --- /dev/null +++ b/docs/source/base_classes.rst @@ -0,0 +1,17 @@ +Base Classes +------------ + +.. autoclass:: nefertari_sqla.documents.BaseMixin + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.documents.BaseDocument + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.documents.ESBaseDocument + :members: + :special-members: + :private-members: \ No newline at end of file diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst new file mode 100644 index 0000000..3fedb45 --- /dev/null +++ b/docs/source/changelog.rst @@ -0,0 +1,10 @@ +Changelog +========= + +* :release:`0.2.0 <2015-04-07>` +* :feature:`-` Relationship indexing + +* :release:`0.1.1 <2015-04-01>` + +* :release:`0.1.0 <2015-04-01>` + diff --git a/docs/source/fields.rst b/docs/source/fields.rst new file mode 100644 index 0000000..7b8b8db --- /dev/null +++ b/docs/source/fields.rst @@ -0,0 +1,102 @@ +Fields +------ + +.. autoclass:: nefertari_sqla.fields.IntegerField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.BigIntegerField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.SmallIntegerField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.BooleanField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.DateField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.DateTimeField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.FloatField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.StringField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.TextField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.UnicodeField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.UnicodeTextField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.ChoiceField + :members: + :special-members: + :private-members: + +.. 
autoclass:: nefertari_sqla.fields.BinaryField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.DecimalField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.TimeField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.PickleField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.IntervalField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.IdField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.ForeignKeyField + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.fields.Relationship + :members: + :special-members: + :private-members: \ No newline at end of file diff --git a/docs/source/serializers.rst b/docs/source/serializers.rst new file mode 100644 index 0000000..7f918a9 --- /dev/null +++ b/docs/source/serializers.rst @@ -0,0 +1,12 @@ +Serializers +----------- + +.. autoclass:: nefertari_sqla.serializers.JSONEncoder + :members: + :special-members: + :private-members: + +.. autoclass:: nefertari_sqla.serializers.ESJSONSerializer + :members: + :special-members: + :private-members: \ No newline at end of file
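
The ``get_relationship_cls`` helper renamed in PATCH 28 keeps the same signature as the old ``relationship_cls``: it takes a relationship field name and a model class and returns the class that field points to. A minimal usage sketch, borrowing the ``Parent``/``Child`` names from the test fixture in PATCH 33 (the models themselves, with a ``children`` relationship, are assumed to be declared elsewhere)::

    # Hedged sketch: Parent/Child mirror the test fixtures and are assumed to
    # be existing declarative models with a 'children' relationship field.
    from nefertari_sqla.utils import is_relationship_field, get_relationship_cls

    if is_relationship_field('children', Parent):
        target_cls = get_relationship_cls('children', Parent)  # -> Child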
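
PATCH 36 makes the ``Relationship`` factory force ``lazy='joined'`` on the 'One' side of One2One/One2Many relationships and on the matching backref, so that updating the related object fires the ORM ``after_update`` event. A minimal sketch of what a hypothetical one-to-one field expands to under that logic; the ``Profile``/``user`` names and the ``backref_uselist`` argument are illustrative, and the ``help_text``/``doc`` handling is omitted::

    from sqlalchemy.orm import relationship, backref

    kwargs = dict(document='Profile', uselist=False,
                  backref_name='user', backref_uselist=False)

    # Split backref_-prefixed kwargs from relationship kwargs, as the factory does.
    backref_kw = {k[len('backref_'):]: v for k, v in kwargs.items()
                  if k.startswith('backref_')}
    rel_kw = {k: v for k, v in kwargs.items() if not k.startswith('backref_')}

    rel_document = rel_kw.pop('document')
    if not rel_kw.get('uselist'):
        rel_kw['lazy'] = 'joined'      # 'One' side is eagerly joined
    if not backref_kw.get('uselist'):
        backref_kw['lazy'] = 'joined'  # same for the One2One backref side
    rel_kw['backref'] = backref(backref_kw.pop('name'), **backref_kw)

    profile = relationship(rel_document, **rel_kw)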
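
The listeners left in ``signals.py`` after PATCH 37 can be summarised from the caller's side. A minimal sketch, assuming ``Story`` is an existing declarative model class; documents built with ``ESMetaclass`` presumably get this registration automatically, so calling it by hand is purely illustrative::

    from nefertari_sqla.signals import setup_es_signals_for

    setup_es_signals_for(Story)
    # For every flushed Story row afterwards:
    #   after_insert -> the row is reloaded via its pk_field() and indexed in
    #                   Elasticsearch along with its references (index_object)
    #   after_update -> column attributes are expired, then the refreshed row
    #                   is re-indexed
    #   after_delete -> the ES document is deleted by its pk value and the
    #                   referencing documents are re-indexed (index_refs)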