Skip to content

Commit

Permalink
set previous values
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jun 20, 2024
1 parent 3213241 commit 567cbc0
Showing 1 changed file with 36 additions and 11 deletions.
47 changes: 36 additions & 11 deletions codeforlife/models/signals/pre_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ def adding(instance: _.AnyModel):
return instance._state.adding


def _generate_get_previous_value(
instance: _.AnyModel,
) -> t.Callable[[str], t.Any]:
if adding(instance):
# pylint: disable-next=unused-argument
def get_previous_value(field: str):
return None

else:
objects = instance.__class__.objects # type: ignore[attr-defined]
previous_instance = objects.get(pk=instance.pk)

def get_previous_value(field: str):
return getattr(previous_instance, field)

return get_previous_value


def check_previous_values(
instance: _.AnyModel,
predicates: t.Dict[str, t.Callable[[t.Any, t.Any], bool]],
Expand All @@ -42,24 +60,31 @@ def check_previous_values(
"""
# pylint: enable=line-too-long

if adding(instance):
# pylint: disable-next=unused-argument
def get_previous_value(field: str):
return None

else:
objects = instance.__class__.objects # type: ignore[attr-defined]
previous_instance = objects.get(pk=instance.pk)

def get_previous_value(field: str):
return getattr(previous_instance, field)
get_previous_value = _generate_get_previous_value(instance)

return all(
predicate(get_previous_value(field), getattr(instance, field))
for field, predicate in predicates.items()
)


def set_previous_values(instance: _.AnyModel, fields: t.Set[str]):
# pylint: disable=line-too-long
"""Set the previous value of the specified fields. All fields are set on the
instance with the naming convention: "previous_{field}".
Args:
instance: The current instance.
fields: The fields to get the previous value for.
"""
# pylint: enable=line-too-long

get_previous_value = _generate_get_previous_value(instance)

for field in fields:
setattr(instance, f"previous_{field}", get_previous_value(field))


def previous_values_are_unequal(instance: _.AnyModel, fields: t.Set[str]):
# pylint: disable=line-too-long
"""Check if all the previous values are not equal to the current values. If
Expand Down

0 comments on commit 567cbc0

Please sign in to comment.