From 4e0f081f914258859dc1759236aebe17816b22c6 Mon Sep 17 00:00:00 2001 From: coxm Date: Wed, 14 Dec 2022 18:02:32 +0000 Subject: [PATCH] Add optional side_effect_meta parameter Allow side effects to receive an additional `side_effect_meta` param which contains the side effect label and the original function's return value. This extra param provides a canonical place to store all current and future side effect metadata, and has a more restricted API than `return_value`, making it easier to reason about. 1. The `side_effect_meta` param _must_ be an explicit param (i.e. it cannot be accessed via `**kwargs` like `return_value`). 2. The parameter _must_ be a keyword-only parameter; this avoids issues with parameter ordering, as well as clashes with `return_value`. 3. The parameter is designed to co-exist with `return_value`, giving consumers a window to adopt the new API. --- side_effects/checks.py | 10 +++-- side_effects/registry.py | 60 ++++++++++++++++++++------- tests/test_registry.py | 90 ++++++++++++++++++++++++++++++---------- 3 files changed, 121 insertions(+), 39 deletions(-) diff --git a/side_effects/checks.py b/side_effects/checks.py index 3254384..58a8d71 100644 --- a/side_effects/checks.py +++ b/side_effects/checks.py @@ -30,11 +30,15 @@ def trim_signature(func: Callable) -> inspect.Signature: sig = inspect.signature(func) # remove return_value from the signature params as it's dynamic # and may/ may not exist depending on the usage. - params = [sig.parameters[p] for p in sig.parameters if p != "return_value"] + params = [ + param + for param_name, param in sig.parameters.items() + if param_name not in ("return_value", "side_effect_meta") + ] return sig.replace(parameters=params, return_annotation=sig.return_annotation) - signatures = [trim_signature(func) for func in registry._registry[label]] - return len(set(signatures)) + signatures = {trim_signature(func) for func in registry._registry[label]} + return len(signatures) @register() diff --git a/side_effects/registry.py b/side_effects/registry.py index f17a0e8..31ffe86 100644 --- a/side_effects/registry.py +++ b/side_effects/registry.py @@ -5,8 +5,9 @@ import logging import threading from collections import defaultdict +from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, Generic, List, TypeVar from django.db import transaction from django.dispatch import Signal @@ -46,6 +47,23 @@ def __init__(self, func: Callable): pass +ReturnValue = TypeVar("ReturnValue") + + +@dataclass +class SideEffectMeta(Generic[ReturnValue]): + """ + Metadata available to all side effect handlers. + + Whenever a side effect is triggered, a `SideEffectMeta` object is created with the + side effect label and the return value of the triggering function. To access it in a + handler, add a keyword-only parameter called `side_effect_meta`. + """ + + label: str + return_value: ReturnValue + + class Registry(defaultdict): """ Registry of side effect functions. @@ -117,12 +135,12 @@ def enable(self) -> None: self._suppress = False def _run_side_effects( - self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any + self, *args: Any, meta: SideEffectMeta, **kwargs: Any ) -> None: if settings.TEST_MODE_FAIL: - raise SideEffectsTestFailure(label) - for func in self[label]: - _run_func(func, *args, return_value=return_value, **kwargs) + raise SideEffectsTestFailure(meta.label) + for func in self[meta.label]: + _run_func(func, *args, meta=meta, **kwargs) def run_side_effects( self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any @@ -140,15 +158,19 @@ def run_side_effects( functions fail hard and early. """ + meta = SideEffectMeta(label=label, return_value=return_value) # TODO: this is all becoming over-complex - need to simplify this - self.try_bind_all(label, *args, return_value=return_value, **kwargs) + self.try_bind_all(*args, meta=meta, **kwargs) if self.is_suppressed: self.suppressed_side_effect.send(Registry, label=label) else: - self._run_side_effects(label, *args, return_value=return_value, **kwargs) + self._run_side_effects(*args, meta=meta, **kwargs) def try_bind_all( - self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any + self, + *args: Any, + meta: SideEffectMeta, + **kwargs: Any, ) -> None: """ Test all receivers for signature compatibility. @@ -156,9 +178,10 @@ def try_bind_all( Raise SignatureMismatch if any function does not match. """ - for func in self[label]: + for func in self[meta.label]: if not ( - try_bind(func, *args, return_value=return_value, **kwargs) + try_bind(func, *args, side_effect_meta=meta, **kwargs) + or try_bind(func, *args, return_value=meta.return_value, **kwargs) or try_bind(func, *args, **kwargs) ): raise SignatureMismatch(func) @@ -222,13 +245,20 @@ def run_side_effects_on_commit( ) -def _run_func( - func: Callable, *args: Any, return_value: Any | None = None, **kwargs: Any -) -> None: +def _run_func(func: Callable, *args: Any, meta: SideEffectMeta, **kwargs: Any) -> None: """Run a single side-effect function and handle errors.""" try: - if try_bind(func, *args, return_value=return_value, **kwargs): - func(*args, return_value=return_value, **kwargs) + # The current return_value logic will pass a return_value to any function + # accepting arbitrary kwargs. Therefore the side_effect_meta check must come + # first. Further, we can't assume that a handler accepting arbitrary **kwargs + # will not expect them to include the return_value. Instead, insist on an + # explicit side_effect_meta parameter; and require it to be keyword-only to + # avoid parameter ordering issues. + meta_param = inspect.signature(func).parameters.get("side_effect_meta") + if meta_param is not None and meta_param.kind == meta_param.KEYWORD_ONLY: + func(*args, side_effect_meta=meta, **kwargs) + elif try_bind(func, *args, return_value=meta.return_value, **kwargs): + func(*args, return_value=meta.return_value, **kwargs) elif try_bind(func, *args, **kwargs): func(*args, **kwargs) else: diff --git a/tests/test_registry.py b/tests/test_registry.py index 5f5c058..81dd83a 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Any from unittest import mock from django.test import TestCase @@ -146,7 +149,8 @@ def test__run_func__no_return_value(self): def test_func(): pass - registry._run_func(test_func, return_value=None) + meta = registry.SideEffectMeta(label="foo", return_value=None) + registry._run_func(test_func, meta=meta) def test__run_func__with_return_value(self): """Test the _run_func function passes through the return_value if required.""" @@ -154,10 +158,8 @@ def test__run_func__with_return_value(self): def test_func(**kwargs): assert "return_value" in kwargs - # return_value not passed through, so fails - registry._run_func(test_func) - # self.assertRaises(KeyError, registry._run_func, test_func) - registry._run_func(test_func, return_value=None) + meta = registry.SideEffectMeta(label="foo", return_value=None) + registry._run_func(test_func, meta=meta) def test__run_func__aborts_on_error(self): """Test the _run_func function handles ABORT_ON_ERROR correctly.""" @@ -165,15 +167,17 @@ def test__run_func__aborts_on_error(self): def test_func(): raise Exception("Pah") + meta = registry.SideEffectMeta(label="foo", return_value=None) + # error is logged, but not raised with mock.patch.object(settings, "ABORT_ON_ERROR", False): self.assertFalse(settings.ABORT_ON_ERROR) - registry._run_func(test_func, return_value=None) + registry._run_func(test_func, meta=meta) # error is raised with mock.patch.object(settings, "ABORT_ON_ERROR", True): self.assertTrue(settings.ABORT_ON_ERROR) - self.assertRaises(Exception, registry._run_func, test_func) + self.assertRaises(Exception, registry._run_func, test_func, meta=meta) def test__run_func__signature_mismatch(self): """Test the _run_func function always raises SignatureMismatch.""" @@ -181,9 +185,14 @@ def test__run_func__signature_mismatch(self): def test_func(): raise Exception("Pah") + meta = registry.SideEffectMeta(label="foo", return_value=None) with mock.patch.object(settings, "ABORT_ON_ERROR", False): self.assertRaises( - registry.SignatureMismatch, registry._run_func, test_func, 1 + registry.SignatureMismatch, + registry._run_func, + test_func, + 1, + meta=meta, ) @@ -223,27 +232,63 @@ def test_func(): self.assertEqual(r.by_label_contains("foo"), {"foo": [test_func]}) self.assertEqual(r.by_label_contains("food"), {}) - @mock.patch("side_effects.registry._run_func") - def test__run_side_effects__no_return_value(self, mock_run): - """Test return_value is not passed""" + def test__run_side_effects__with_side_effect_meta(self) -> None: + """Test the meta object is passed if the function requires it explictly.""" + actual_call_a = mock.Mock() + actual_call_b = mock.Mock() + actual_call_c = mock.Mock() + r = registry.Registry() + + def handler_a(*, side_effect_meta: registry.SideEffectMeta) -> None: + actual_call_a(side_effect_meta=side_effect_meta) + + def handler_b(*args: Any, side_effect_meta: registry.SideEffectMeta) -> None: + actual_call_b(*args, side_effect_meta=side_effect_meta) + + def handler_c( + *, side_effect_meta: registry.SideEffectMeta, **kwargs: Any + ) -> None: + actual_call_c(side_effect_meta=side_effect_meta, **kwargs) + + r.add("a", handler_a) + r.add("b", handler_b) + r.add("c", handler_c) - def no_return_value(*args, **kwargz): - assert "return_value" not in kwargz + meta_a = registry.SideEffectMeta(label="a", return_value=None) + r._run_side_effects(meta=meta_a) + actual_call_a.assert_called_once_with(side_effect_meta=meta_a) + meta_b = registry.SideEffectMeta(label="b", return_value=None) + r._run_side_effects(1, 2, 3, meta=meta_b) + actual_call_b.assert_called_once_with(1, 2, 3, side_effect_meta=meta_b) + + meta_c = registry.SideEffectMeta(label="c", return_value=None) + r._run_side_effects(meta=meta_c, x=1, y=2) + actual_call_c.assert_called_once_with(side_effect_meta=meta_c, x=1, y=2) + + def test__run_side_effects__side_effect_meta_must_be_keyword_only(self) -> None: + """Test that the meta object is not passed if not a keyword-only param.""" + meta = registry.SideEffectMeta(label="foo", return_value=None) r = registry.Registry() - r.add("foo", no_return_value) - r._run_side_effects("foo") - r._run_side_effects("foo", return_value=None) + + def handler(side_effect_meta: registry.SideEffectMeta, *args, **kwargs: Any) -> None: + pass + + r.add(meta.label, handler) + self.assertRaises(registry.SignatureMismatch, r._run_side_effects, meta=meta) def test__run_side_effects__with_return_value(self): - """Test return_value is passed""" + """Test return_value is passed if the function has **kwargs.""" + actual_call = mock.Mock() r = registry.Registry() def has_return_value(*args, **kwargs): - assert "return_value" in kwargs + actual_call(*args, **kwargs) r.add("foo", has_return_value) - r._run_side_effects("foo", return_value=None) + meta = registry.SideEffectMeta(label="foo", return_value=None) + r._run_side_effects(meta=meta) + actual_call.assert_called_once_with(return_value=None) def test_try_bind_all(self): def foo1(return_value): @@ -267,5 +312,8 @@ def foo5(arg1, **kwargs): r.add("foo", foo3) r.add("foo", foo4) r.add("foo", foo5) - r.try_bind_all("foo", 1) - self.assertRaises(registry.SignatureMismatch, r.try_bind_all, "foo", 1, 2) + meta = registry.SideEffectMeta(label="foo", return_value=None) + r.try_bind_all(1, meta=meta) + self.assertRaises( + registry.SignatureMismatch, r.try_bind_all, 1, 2, meta=meta + )