From ea13824a6f50b6619f14ca65dd30cc92e20d2b2c Mon Sep 17 00:00:00 2001 From: Andras Tim Date: Thu, 26 Nov 2015 08:39:46 +0100 Subject: [PATCH] server: Added change detection to Base views * Refactored ModelDataDiffer * Use distinct ModelDataDiffer instance for views (super class variables are initialized only once) * Followed up Base views changes in Work view --- server/app/modules/view_helper_for_models.py | 54 +++++++++++++++----- server/app/views/base_views.py | 32 +++++++++--- server/app/views/work.py | 5 +- 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/server/app/modules/view_helper_for_models.py b/server/app/modules/view_helper_for_models.py index 4112ae52..f70a2b68 100644 --- a/server/app/modules/view_helper_for_models.py +++ b/server/app/modules/view_helper_for_models.py @@ -1,4 +1,5 @@ import re +from enum import Enum, unique from flask import request from marshmallow import Serializer from sqlalchemy.exc import IntegrityError @@ -133,46 +134,71 @@ def get_validated_request(deserializer: Serializer) -> (dict, list, None): return data +@unique +class ModelDataStates(Enum): + original = 'original' + populated = 'populated' + + class ModelDataDiffer: """ Tool for checking changes of the specified fields """ def __init__(self): + self.__fields = set() self.__nested_fields = {} - self.__original_values = {} - def save_state(self, model: db.Model): - self.__initial_enumerate_nested_fields(model) + self.__states = {} + for state in ModelDataStates: + self.__states[state.value] = {} + + def save_state(self, model: (db.Model, None), state: ModelDataStates): + state_object = self.__states[state] = {} + if model is None: + return + + self.__initial_enumerate_fields(model) + for field_name in self.__fields: + state_object[field_name] = self.__get_field_value(model, field_name) - self.__original_values = {} - for column in model.__table__.columns: - self.__original_values[column.name] = self.__get_field_value(model, column.name) + def get_state(self, state: ModelDataStates) -> dict: + return self.__states[state] - def get_diff(self, model: db.Model) -> dict: + def diff_states(self) -> dict: diff = {} + original_values = self.__states[ModelDataStates.original] + populated_values = self.__states[ModelDataStates.populated] - for column in model.__table__.columns: - original = self.__original_values[column.name] - current = self.__get_field_value(model, column.name) + for field_name in self.__fields: + original = original_values[field_name] + current = populated_values[field_name] if original != current: - diff[column.name] = {'original': original, 'current': current} + diff[field_name] = {'original': original, 'current': current} return diff - def __initial_enumerate_nested_fields(self, model: db.Model): - if self.__nested_fields: + def __initial_enumerate_fields(self, model: db.Model): + if self.__fields: return + for field in model.__table__.columns: + self.__fields.add(field.name) + relationships = inspect(model.__class__).relationships for relationship in relationships: nested_field_name = relationship.key for local_field, remote_field in relationship.local_remote_pairs: - self.__nested_fields[local_field.name] = [nested_field_name, remote_field.name] + if local_field.foreign_keys: + self.__nested_fields[local_field.name] = [nested_field_name, remote_field.name] def __get_field_value(self, model: db.Model, name: str): if name not in self.__nested_fields.keys(): return getattr(model, name) nested_field_name, remote_field_name = self.__nested_fields[name] + nested_field = getattr(model, nested_field_name) + if nested_field is None: + return None + return getattr(nested_field, remote_field_name) diff --git a/server/app/views/base_views.py b/server/app/views/base_views.py index e2962b8e..f8e362e0 100644 --- a/server/app/views/base_views.py +++ b/server/app/views/base_views.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Query from sqlalchemy.exc import IntegrityError -from app.modules.view_helper_for_models import PopulateModelOnSubmit, ModelDataDiffer, SqlErrorParser +from app.modules.view_helper_for_models import PopulateModelOnSubmit, ModelDataDiffer, SqlErrorParser, ModelDataStates from app.server import db from app.views.common import commit_with_error_handling, commit_and_rollback_on_error @@ -13,7 +13,10 @@ class _BaseModelResource(Resource): _parent_model = None _serializer = None _deserializer = None - __differ = ModelDataDiffer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__differ = ModelDataDiffer() @property def _query(self) -> Query: @@ -33,11 +36,19 @@ def _serialize_many(self, items) -> list: def _initialize_parent_item(self, parent_id: int) -> 'item': return _initialize_parent_item(self._parent_model, parent_id) - def _save_original_before_populate(self, item): - self.__differ.save_state(item) + def _save_original_state(self, item=None): + self.__differ.save_state(item, ModelDataStates.original) + + def _save_populated_state(self, item=None): + self.__differ.save_state(item, ModelDataStates.populated) + + def _get_population_changes(self) -> dict: + return self.__differ.diff_states() - def _get_populate_diff(self, item) -> dict: - return self.__differ.get_diff(item) + def _create_populated_changelog(self): + original = self.__differ.get_state(ModelDataStates.original) + populated = self.__differ.get_state(ModelDataStates.populated) + print(original, populated) class BaseListView(_BaseModelResource): @@ -81,10 +92,12 @@ def _post_populate(self, **set_fields) -> 'item': """ Populate a new object """ + self._save_original_state() item = self._model() for name, value in set_fields.items(): setattr(item, name, value) self._populate_item(item) + self._save_populated_state(item) return item def _post_commit(self, item) -> 'RPC response': @@ -167,7 +180,9 @@ def _put_populate(self, **filter) -> 'item': Populate change object """ item = self._get_item_by_filter(**filter) + self._save_original_state(item) self._populate_item(item) + self._save_populated_state(item) return item def _put_commit(self, item) -> 'RPC response': @@ -181,7 +196,10 @@ def _delete_get_item(self, **filter) -> 'item': """ Getting object for delete """ - return self._get_item_by_filter(**filter) + item = self._get_item_by_filter(**filter) + self._save_original_state(item) + self._save_populated_state() + return item def _delete_commit(self, item) -> None: """ diff --git a/server/app/views/work.py b/server/app/views/work.py index 8da6e66c..ffbaac89 100644 --- a/server/app/views/work.py +++ b/server/app/views/work.py @@ -100,11 +100,8 @@ def get(self, id: int, item_id: int): def put(self, id: int, item_id: int): work = self._initialize_parent_item(id) - original_item = self._get_item_by_id(item_id) - self._save_original_before_populate(original_item) - modified_item = self._put_populate(work_id=id, id=item_id) - changed_fields = self._get_populate_diff(modified_item).keys() + changed_fields = self._get_population_changes().keys() if self.__is_tried_to_change_closed(work, changed_fields): abort(403, message='Work item was closed.')