Skip to content

Commit

Permalink
Proper fix for requirement leaks across sibling classes (closes #928)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmalloc committed Mar 13, 2019
1 parent 14e5af3 commit 84467ac
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 71 deletions.
32 changes: 25 additions & 7 deletions slash/core/requirements.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ..utils.python import resolve_underlying_function

_SLASH_REQUIRES_KEY_NAME = '__slash_requirements__'


Expand All @@ -12,18 +14,34 @@ def requires(req, message=None):
else:
assert message is None, 'Cannot specify message when passing Requirement objects to slash.requires'

def decorator(func):
reqs = getattr(func, _SLASH_REQUIRES_KEY_NAME, None)
if reqs is None:
reqs = []
setattr(func, _SLASH_REQUIRES_KEY_NAME, reqs)
def decorator(func_or_class):
reqs = _get_requirements_list(func_or_class)
reqs.append(req)
return func
return func_or_class
return decorator

def _get_requirements_list(thing, create=True):

thing = resolve_underlying_function(thing)
existing = getattr(thing, _SLASH_REQUIRES_KEY_NAME, None)

key = id(thing)


if existing is None or key != existing[0]:
new_reqs = (key, [] if existing is None else existing[1][:])
if create:
setattr(thing, _SLASH_REQUIRES_KEY_NAME, new_reqs)
assert thing.__slash_requirements__ is new_reqs
returned = new_reqs[1]
else:
returned = existing[1]

return returned


def get_requirements(test):
return list(getattr(test, _SLASH_REQUIRES_KEY_NAME, []))
return list(_get_requirements_list(test, create=False))


class Requirement(object):
Expand Down
12 changes: 12 additions & 0 deletions slash/utils/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,15 @@ def call_all_raise_first(_funcs, *args, **kwargs):
exc_info = sys.exc_info()
if exc_info is not None:
reraise(*exc_info)


def resolve_underlying_function(thing):
"""Gets the underlying (real) function for functions, wrapped functions, methods, etc.
Returns the same object for things that are not functions
"""
while True:
wrapped = getattr(thing, "__func__", None) or getattr(thing, "__wrapped__", None) or getattr(thing, "__wraps__", None)
if wrapped is None:
break
thing = wrapped
return thing
59 changes: 58 additions & 1 deletion tests/test_python_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# pylint: disable=redefined-outer-name
import pytest
from slash._compat import PY2

from slash.utils.python import call_all_raise_first
if PY2:
from slash.utils.python import wraps
else:
from functools import wraps

from slash.utils.python import call_all_raise_first, resolve_underlying_function


def test_call_all_raise_first(funcs):
Expand Down Expand Up @@ -35,3 +41,54 @@ class CustomException(Exception):
return self.exc_type

return [Func() for _ in range(10)]


@pytest.mark.parametrize('class_method', [True, False])
def test_resolve_underlying_function_method(class_method):
if class_method:
decorator = classmethod
else:
decorator = lambda f: f

class Blap(object):

@decorator
def method(self):
pass

resolved = resolve_underlying_function(Blap.method)
assert resolved is resolve_underlying_function(Blap.method) # stable
assert not hasattr(resolved, '__func__')
assert resolved.__name__ == 'method'


@pytest.mark.parametrize('thing', [object(), object, None, 2, "string"])
def test_resolve_underlying_function_method_no_op(thing):
assert resolve_underlying_function(thing) is thing


def _example_decorator(func):
@wraps(func)
def new_func():
pass

return new_func

def test_resolve_underlying_decorator_regular_func():

def orig():
pass
decorated = _example_decorator(orig)
assert resolve_underlying_function(decorated) is orig

def test_resolve_underlying_decorator_method():

class Blap(object):

def orig(self):
pass

decorated = _example_decorator(orig)

assert resolve_underlying_function(Blap.decorated) is resolve_underlying_function(Blap.orig)
assert resolve_underlying_function(Blap.decorated).__name__ == 'orig'
Loading

0 comments on commit 84467ac

Please sign in to comment.