From 567cbc05d9788b804e8198f2c4729f07d60de594 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 20 Jun 2024 10:38:07 +0000 Subject: [PATCH] set previous values --- codeforlife/models/signals/pre_save.py | 47 ++++++++++++++++++++------ 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/codeforlife/models/signals/pre_save.py b/codeforlife/models/signals/pre_save.py index 44aeb104..398c2062 100644 --- a/codeforlife/models/signals/pre_save.py +++ b/codeforlife/models/signals/pre_save.py @@ -25,6 +25,24 @@ def adding(instance: _.AnyModel): return instance._state.adding +def _generate_get_previous_value( + instance: _.AnyModel, +) -> t.Callable[[str], t.Any]: + if adding(instance): + # pylint: disable-next=unused-argument + def get_previous_value(field: str): + return None + + else: + objects = instance.__class__.objects # type: ignore[attr-defined] + previous_instance = objects.get(pk=instance.pk) + + def get_previous_value(field: str): + return getattr(previous_instance, field) + + return get_previous_value + + def check_previous_values( instance: _.AnyModel, predicates: t.Dict[str, t.Callable[[t.Any, t.Any], bool]], @@ -42,17 +60,7 @@ def check_previous_values( """ # pylint: enable=line-too-long - if adding(instance): - # pylint: disable-next=unused-argument - def get_previous_value(field: str): - return None - - else: - objects = instance.__class__.objects # type: ignore[attr-defined] - previous_instance = objects.get(pk=instance.pk) - - def get_previous_value(field: str): - return getattr(previous_instance, field) + get_previous_value = _generate_get_previous_value(instance) return all( predicate(get_previous_value(field), getattr(instance, field)) @@ -60,6 +68,23 @@ def get_previous_value(field: str): ) +def set_previous_values(instance: _.AnyModel, fields: t.Set[str]): + # pylint: disable=line-too-long + """Set the previous value of the specified fields. All fields are set on the + instance with the naming convention: "previous_{field}". + + Args: + instance: The current instance. + fields: The fields to get the previous value for. + """ + # pylint: enable=line-too-long + + get_previous_value = _generate_get_previous_value(instance) + + for field in fields: + setattr(instance, f"previous_{field}", get_previous_value(field)) + + def previous_values_are_unequal(instance: _.AnyModel, fields: t.Set[str]): # pylint: disable=line-too-long """Check if all the previous values are not equal to the current values. If