Skip to content

Commit

Permalink
server: Added change detection to Base views
Browse files Browse the repository at this point in the history
 * Refactored ModelDataDiffer
 * Use distinct ModelDataDiffer instance for views
   (super class variables are initialized only once)
 * Followed up Base views changes in Work view
  • Loading branch information
andras-tim committed Nov 26, 2015
1 parent b0a502b commit ea13824
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 25 deletions.
54 changes: 40 additions & 14 deletions server/app/modules/view_helper_for_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from enum import Enum, unique
from flask import request
from marshmallow import Serializer
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -133,46 +134,71 @@ def get_validated_request(deserializer: Serializer) -> (dict, list, None):
return data


@unique
class ModelDataStates(Enum):
original = 'original'
populated = 'populated'


class ModelDataDiffer:
"""
Tool for checking changes of the specified fields
"""
def __init__(self):
self.__fields = set()
self.__nested_fields = {}
self.__original_values = {}

def save_state(self, model: db.Model):
self.__initial_enumerate_nested_fields(model)
self.__states = {}
for state in ModelDataStates:
self.__states[state.value] = {}

def save_state(self, model: (db.Model, None), state: ModelDataStates):
state_object = self.__states[state] = {}
if model is None:
return

self.__initial_enumerate_fields(model)
for field_name in self.__fields:
state_object[field_name] = self.__get_field_value(model, field_name)

self.__original_values = {}
for column in model.__table__.columns:
self.__original_values[column.name] = self.__get_field_value(model, column.name)
def get_state(self, state: ModelDataStates) -> dict:
return self.__states[state]

def get_diff(self, model: db.Model) -> dict:
def diff_states(self) -> dict:
diff = {}
original_values = self.__states[ModelDataStates.original]
populated_values = self.__states[ModelDataStates.populated]

for column in model.__table__.columns:
original = self.__original_values[column.name]
current = self.__get_field_value(model, column.name)
for field_name in self.__fields:
original = original_values[field_name]
current = populated_values[field_name]
if original != current:
diff[column.name] = {'original': original, 'current': current}
diff[field_name] = {'original': original, 'current': current}

return diff

def __initial_enumerate_nested_fields(self, model: db.Model):
if self.__nested_fields:
def __initial_enumerate_fields(self, model: db.Model):
if self.__fields:
return

for field in model.__table__.columns:
self.__fields.add(field.name)

relationships = inspect(model.__class__).relationships
for relationship in relationships:
nested_field_name = relationship.key
for local_field, remote_field in relationship.local_remote_pairs:
self.__nested_fields[local_field.name] = [nested_field_name, remote_field.name]
if local_field.foreign_keys:
self.__nested_fields[local_field.name] = [nested_field_name, remote_field.name]

def __get_field_value(self, model: db.Model, name: str):
if name not in self.__nested_fields.keys():
return getattr(model, name)

nested_field_name, remote_field_name = self.__nested_fields[name]

nested_field = getattr(model, nested_field_name)
if nested_field is None:
return None

return getattr(nested_field, remote_field_name)
32 changes: 25 additions & 7 deletions server/app/views/base_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy.orm import Query
from sqlalchemy.exc import IntegrityError

from app.modules.view_helper_for_models import PopulateModelOnSubmit, ModelDataDiffer, SqlErrorParser
from app.modules.view_helper_for_models import PopulateModelOnSubmit, ModelDataDiffer, SqlErrorParser, ModelDataStates
from app.server import db
from app.views.common import commit_with_error_handling, commit_and_rollback_on_error

Expand All @@ -13,7 +13,10 @@ class _BaseModelResource(Resource):
_parent_model = None
_serializer = None
_deserializer = None
__differ = ModelDataDiffer()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__differ = ModelDataDiffer()

@property
def _query(self) -> Query:
Expand All @@ -33,11 +36,19 @@ def _serialize_many(self, items) -> list:
def _initialize_parent_item(self, parent_id: int) -> 'item':
return _initialize_parent_item(self._parent_model, parent_id)

def _save_original_before_populate(self, item):
self.__differ.save_state(item)
def _save_original_state(self, item=None):
self.__differ.save_state(item, ModelDataStates.original)

def _save_populated_state(self, item=None):
self.__differ.save_state(item, ModelDataStates.populated)

def _get_population_changes(self) -> dict:
return self.__differ.diff_states()

def _get_populate_diff(self, item) -> dict:
return self.__differ.get_diff(item)
def _create_populated_changelog(self):
original = self.__differ.get_state(ModelDataStates.original)
populated = self.__differ.get_state(ModelDataStates.populated)
print(original, populated)


class BaseListView(_BaseModelResource):
Expand Down Expand Up @@ -81,10 +92,12 @@ def _post_populate(self, **set_fields) -> 'item':
"""
Populate a new object
"""
self._save_original_state()
item = self._model()
for name, value in set_fields.items():
setattr(item, name, value)
self._populate_item(item)
self._save_populated_state(item)
return item

def _post_commit(self, item) -> 'RPC response':
Expand Down Expand Up @@ -167,7 +180,9 @@ def _put_populate(self, **filter) -> 'item':
Populate change object
"""
item = self._get_item_by_filter(**filter)
self._save_original_state(item)
self._populate_item(item)
self._save_populated_state(item)
return item

def _put_commit(self, item) -> 'RPC response':
Expand All @@ -181,7 +196,10 @@ def _delete_get_item(self, **filter) -> 'item':
"""
Getting object for delete
"""
return self._get_item_by_filter(**filter)
item = self._get_item_by_filter(**filter)
self._save_original_state(item)
self._save_populated_state()
return item

def _delete_commit(self, item) -> None:
"""
Expand Down
5 changes: 1 addition & 4 deletions server/app/views/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,8 @@ def get(self, id: int, item_id: int):
def put(self, id: int, item_id: int):
work = self._initialize_parent_item(id)

original_item = self._get_item_by_id(item_id)
self._save_original_before_populate(original_item)

modified_item = self._put_populate(work_id=id, id=item_id)
changed_fields = self._get_populate_diff(modified_item).keys()
changed_fields = self._get_population_changes().keys()

if self.__is_tried_to_change_closed(work, changed_fields):
abort(403, message='Work item was closed.')
Expand Down

0 comments on commit ea13824

Please sign in to comment.