Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional side_effect_meta parameter #31

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions side_effects/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ def trim_signature(func: Callable) -> inspect.Signature:
sig = inspect.signature(func)
# remove return_value from the signature params as it's dynamic
# and may/ may not exist depending on the usage.
params = [sig.parameters[p] for p in sig.parameters if p != "return_value"]
params = [
param
for param_name, param in sig.parameters.items()
if param_name not in ("return_value", "side_effect_meta")
]
return sig.replace(parameters=params, return_annotation=sig.return_annotation)

signatures = [trim_signature(func) for func in registry._registry[label]]
return len(set(signatures))
signatures = {trim_signature(func) for func in registry._registry[label]}
return len(signatures)


@register()
Expand Down
60 changes: 45 additions & 15 deletions side_effects/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import logging
import threading
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, Generic, List, TypeVar

from django.db import transaction
from django.dispatch import Signal
Expand Down Expand Up @@ -46,6 +47,23 @@ def __init__(self, func: Callable):
pass


ReturnValue = TypeVar("ReturnValue")


@dataclass
class SideEffectMeta(Generic[ReturnValue]):
"""
Metadata available to all side effect handlers.

Whenever a side effect is triggered, a `SideEffectMeta` object is created with the
side effect label and the return value of the triggering function. To access it in a
handler, add a keyword-only parameter called `side_effect_meta`.
"""

label: str
return_value: ReturnValue


class Registry(defaultdict):
"""
Registry of side effect functions.
Expand Down Expand Up @@ -117,12 +135,12 @@ def enable(self) -> None:
self._suppress = False

def _run_side_effects(
self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any
self, *args: Any, meta: SideEffectMeta, **kwargs: Any
) -> None:
if settings.TEST_MODE_FAIL:
raise SideEffectsTestFailure(label)
for func in self[label]:
_run_func(func, *args, return_value=return_value, **kwargs)
raise SideEffectsTestFailure(meta.label)
for func in self[meta.label]:
_run_func(func, *args, meta=meta, **kwargs)

def run_side_effects(
self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any
Expand All @@ -140,25 +158,30 @@ def run_side_effects(
functions fail hard and early.

"""
meta = SideEffectMeta(label=label, return_value=return_value)
# TODO: this is all becoming over-complex - need to simplify this
self.try_bind_all(label, *args, return_value=return_value, **kwargs)
self.try_bind_all(*args, meta=meta, **kwargs)
if self.is_suppressed:
self.suppressed_side_effect.send(Registry, label=label)
else:
self._run_side_effects(label, *args, return_value=return_value, **kwargs)
self._run_side_effects(*args, meta=meta, **kwargs)

def try_bind_all(
self, label: str, *args: Any, return_value: Any | None = None, **kwargs: Any
self,
*args: Any,
meta: SideEffectMeta,
**kwargs: Any,
) -> None:
"""
Test all receivers for signature compatibility.

Raise SignatureMismatch if any function does not match.

"""
for func in self[label]:
for func in self[meta.label]:
if not (
try_bind(func, *args, return_value=return_value, **kwargs)
try_bind(func, *args, side_effect_meta=meta, **kwargs)
or try_bind(func, *args, return_value=meta.return_value, **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're temporarily adding complexity here, but if the return_value were to be removed in favour of side_effect_meta, we could simplify this signature checking logic considerably.

The return_value bind check could be removed immediately, but also the side_effect_meta check is simpler. Crucially, return_value requires us to check whether the signature can be bound, because it is injected into **kwargs when not explicitly present as a parameter, whereas the need for side_effect_meta can be reduced to checking for a particular parameter name & type in the signature.

This potentially means that signatures could be checked inside the decorators_ has_side_effectsandis_side_effect_of(oncerun_side_effects` was removed) - i.e. on initialisation, rather than at runtime, every time a side effect is triggered.

or try_bind(func, *args, **kwargs)
):
raise SignatureMismatch(func)
Expand Down Expand Up @@ -222,13 +245,20 @@ def run_side_effects_on_commit(
)


def _run_func(
func: Callable, *args: Any, return_value: Any | None = None, **kwargs: Any
) -> None:
def _run_func(func: Callable, *args: Any, meta: SideEffectMeta, **kwargs: Any) -> None:
"""Run a single side-effect function and handle errors."""
try:
if try_bind(func, *args, return_value=return_value, **kwargs):
func(*args, return_value=return_value, **kwargs)
# The current return_value logic will pass a return_value to any function
# accepting arbitrary kwargs. Therefore the side_effect_meta check must come
# first. Further, we can't assume that a handler accepting arbitrary **kwargs
# will not expect them to include the return_value. Instead, insist on an
# explicit side_effect_meta parameter; and require it to be keyword-only to
# avoid parameter ordering issues.
meta_param = inspect.signature(func).parameters.get("side_effect_meta")
if meta_param is not None and meta_param.kind == meta_param.KEYWORD_ONLY:
func(*args, side_effect_meta=meta, **kwargs)
elif try_bind(func, *args, return_value=meta.return_value, **kwargs):
func(*args, return_value=meta.return_value, **kwargs)
elif try_bind(func, *args, **kwargs):
func(*args, **kwargs)
else:
Expand Down
90 changes: 69 additions & 21 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

from typing import Any
from unittest import mock

from django.test import TestCase
Expand Down Expand Up @@ -146,44 +149,50 @@ def test__run_func__no_return_value(self):
def test_func():
pass

registry._run_func(test_func, return_value=None)
meta = registry.SideEffectMeta(label="foo", return_value=None)
registry._run_func(test_func, meta=meta)

def test__run_func__with_return_value(self):
"""Test the _run_func function passes through the return_value if required."""

def test_func(**kwargs):
assert "return_value" in kwargs

# return_value not passed through, so fails
registry._run_func(test_func)
# self.assertRaises(KeyError, registry._run_func, test_func)
registry._run_func(test_func, return_value=None)
meta = registry.SideEffectMeta(label="foo", return_value=None)
registry._run_func(test_func, meta=meta)

def test__run_func__aborts_on_error(self):
"""Test the _run_func function handles ABORT_ON_ERROR correctly."""

def test_func():
raise Exception("Pah")

meta = registry.SideEffectMeta(label="foo", return_value=None)

# error is logged, but not raised
with mock.patch.object(settings, "ABORT_ON_ERROR", False):
self.assertFalse(settings.ABORT_ON_ERROR)
registry._run_func(test_func, return_value=None)
registry._run_func(test_func, meta=meta)

# error is raised
with mock.patch.object(settings, "ABORT_ON_ERROR", True):
self.assertTrue(settings.ABORT_ON_ERROR)
self.assertRaises(Exception, registry._run_func, test_func)
self.assertRaises(Exception, registry._run_func, test_func, meta=meta)

def test__run_func__signature_mismatch(self):
"""Test the _run_func function always raises SignatureMismatch."""

def test_func():
raise Exception("Pah")

meta = registry.SideEffectMeta(label="foo", return_value=None)
with mock.patch.object(settings, "ABORT_ON_ERROR", False):
self.assertRaises(
registry.SignatureMismatch, registry._run_func, test_func, 1
registry.SignatureMismatch,
registry._run_func,
test_func,
1,
meta=meta,
)


Expand Down Expand Up @@ -223,27 +232,63 @@ def test_func():
self.assertEqual(r.by_label_contains("foo"), {"foo": [test_func]})
self.assertEqual(r.by_label_contains("food"), {})

@mock.patch("side_effects.registry._run_func")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been removed (not rewritten). It had hidden failures, and was incorrect: no_return_value does actually receive a return_value in kwargz.

(The test__run_side_effects__no_return_value test failed silently because _run_func was mocked, but even after removing that, the assertion assert "return_value" not in kwargz actually failed. This failure would not be detected, though, because settings.ABORT_ON_FAILURE is off by default.)

def test__run_side_effects__no_return_value(self, mock_run):
"""Test return_value is not passed"""
def test__run_side_effects__with_side_effect_meta(self) -> None:
"""Test the meta object is passed if the function requires it explictly."""
actual_call_a = mock.Mock()
actual_call_b = mock.Mock()
actual_call_c = mock.Mock()
r = registry.Registry()

def handler_a(*, side_effect_meta: registry.SideEffectMeta) -> None:
actual_call_a(side_effect_meta=side_effect_meta)

def handler_b(*args: Any, side_effect_meta: registry.SideEffectMeta) -> None:
actual_call_b(*args, side_effect_meta=side_effect_meta)

def handler_c(
*, side_effect_meta: registry.SideEffectMeta, **kwargs: Any
) -> None:
actual_call_c(side_effect_meta=side_effect_meta, **kwargs)

r.add("a", handler_a)
r.add("b", handler_b)
r.add("c", handler_c)

def no_return_value(*args, **kwargz):
assert "return_value" not in kwargz
meta_a = registry.SideEffectMeta(label="a", return_value=None)
r._run_side_effects(meta=meta_a)
actual_call_a.assert_called_once_with(side_effect_meta=meta_a)

meta_b = registry.SideEffectMeta(label="b", return_value=None)
r._run_side_effects(1, 2, 3, meta=meta_b)
actual_call_b.assert_called_once_with(1, 2, 3, side_effect_meta=meta_b)

meta_c = registry.SideEffectMeta(label="c", return_value=None)
r._run_side_effects(meta=meta_c, x=1, y=2)
actual_call_c.assert_called_once_with(side_effect_meta=meta_c, x=1, y=2)

def test__run_side_effects__side_effect_meta_must_be_keyword_only(self) -> None:
"""Test that the meta object is not passed if not a keyword-only param."""
meta = registry.SideEffectMeta(label="foo", return_value=None)
r = registry.Registry()
r.add("foo", no_return_value)
r._run_side_effects("foo")
r._run_side_effects("foo", return_value=None)

def handler(side_effect_meta: registry.SideEffectMeta, *args, **kwargs: Any) -> None:
pass

r.add(meta.label, handler)
self.assertRaises(registry.SignatureMismatch, r._run_side_effects, meta=meta)

def test__run_side_effects__with_return_value(self):
"""Test return_value is passed"""
"""Test return_value is passed if the function has **kwargs."""
actual_call = mock.Mock()
r = registry.Registry()

def has_return_value(*args, **kwargs):
assert "return_value" in kwargs
actual_call(*args, **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a mock here to monitor the parameters means that the assertion can't be accidentally swallowed by the try/except in _run_func.


r.add("foo", has_return_value)
r._run_side_effects("foo", return_value=None)
meta = registry.SideEffectMeta(label="foo", return_value=None)
r._run_side_effects(meta=meta)
actual_call.assert_called_once_with(return_value=None)

def test_try_bind_all(self):
def foo1(return_value):
Expand All @@ -267,5 +312,8 @@ def foo5(arg1, **kwargs):
r.add("foo", foo3)
r.add("foo", foo4)
r.add("foo", foo5)
r.try_bind_all("foo", 1)
self.assertRaises(registry.SignatureMismatch, r.try_bind_all, "foo", 1, 2)
meta = registry.SideEffectMeta(label="foo", return_value=None)
r.try_bind_all(1, meta=meta)
self.assertRaises(
registry.SignatureMismatch, r.try_bind_all, 1, 2, meta=meta
)