diff --git a/nefertari_mongodb/documents.py b/nefertari_mongodb/documents.py index 8624760..13e68d2 100644 --- a/nefertari_mongodb/documents.py +++ b/nefertari_mongodb/documents.py @@ -527,19 +527,32 @@ def update_iterables(self, params, attr, unique=False, is_dict = isinstance(type(self)._fields[attr], mongo.DictField) is_list = isinstance(type(self)._fields[attr], mongo.ListField) + if is_list: + is_list_of_dicts=isinstance(type(self)._fields[attr].item_type, type(mongo.DictField)) def split_keys(keys): neg_keys = [] pos_keys = [] - + self_keys = [] + if is_list_of_dicts: + self_keys= getattr(self,attr) for key in keys: + #edited to support dicts in a array + if isinstance(key,dict): + if key in getattr(self,attr): + self_keys.remove(key) + pos_keys.append(key) + continue if key.startswith('__'): continue if key.startswith('-'): neg_keys.append(key[1:]) else: pos_keys.append(key.strip()) + if self_keys: + neg_keys.extend(self_keys) return pos_keys, neg_keys + def update_dict(update_params): final_value = getattr(self, attr, {}) or {} final_value = final_value.copy() @@ -585,7 +598,10 @@ def update_list(update_params): final_value += positive if negative: - final_value = list(set(final_value) - set(negative)) + if is_list_of_dicts: + [final_value.remove(i) for i in final_value if i in negative] + else: + final_value = list(set(final_value) - set(negative)) setattr(self, attr, final_value) if save: @@ -597,6 +613,12 @@ def update_list(update_params): elif is_list: update_list(params) + if is_dict: + update_dict(params) + + elif is_list: + update_list(params) + @classmethod def expand_with(cls, with_cls, join_on=None, attr_name=None, params={}, with_params={}):