diff --git a/codeforlife/models/signals/post_save.py b/codeforlife/models/signals/post_save.py index 37f5a988..e42e14ec 100644 --- a/codeforlife/models/signals/post_save.py +++ b/codeforlife/models/signals/post_save.py @@ -11,6 +11,8 @@ from . import general as _ from .pre_save import PREVIOUS_VALUE_KEY +FieldValue = t.TypeVar("FieldValue") + def has_previous_values(instance: _.AnyModel, fields: t.Dict[str, t.Type]): # pylint: disable=line-too-long @@ -36,3 +38,27 @@ def has_previous_values(instance: _.AnyModel, fields: t.Dict[str, t.Type]): return False return True + + +def get_previous_value( + instance: _.AnyModel, field: str, cls: t.Type[FieldValue] +): + # pylint: disable=line-too-long + """Get a previous value from the instance and assert the value is of the + expected type. + + Args: + instance: The current instance. + field: The field to get the previous value for. + cls: The expected type of the value. + + Returns: + The previous value of the field. + """ + # pylint: enable=line-too-long + + previous_value = getattr(instance, PREVIOUS_VALUE_KEY.format(field=field)) + + assert isinstance(previous_value, cls) + + return previous_value